
Commit 2866b68

[DRAFT, Example] Add MCTS example
ghstack-source-id: bd98430
Pull Request resolved: pytorch#2796
1 parent 8c9dc05 commit 2866b68

7 files changed (+356 lines, -30 lines)

examples/trees/mcts.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchrl
from tensordict import TensorDict

pgn_or_fen = "fen"

mask_actions = False

env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=mask_actions,
)


def transform_reward(td):
    if "reward" not in td:
        return td
    reward = td["reward"]
    if reward == 0.5:
        td["reward"] = 0
    elif reward == 1 and td["turn"]:
        td["reward"] = -td["reward"]
    return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
# white win = 1
# draw = 0
# black win = -1
env = env.append_transform(transform_reward)

forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys + ["_visits", "_reward_sum"]
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys

if mask_actions:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]
else:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn"]

C = 2.0**0.5


def traversal_priority_UCB1(tree, root_visits):
    subtree = tree.subtree
    td_subtree = subtree.rollout[:, -1]["next"]
    visits = td_subtree["_visits"]
    reward_sum = td_subtree["_reward_sum"].clone()

    # If it's black's turn, flip the reward, since black wants to
    # optimize for the lowest reward, not the highest.
    if not subtree.rollout[0, 0]["turn"]:
        reward_sum = -reward_sum

    if tree.rollout is None:
        parent_visits = root_visits
    else:
        parent_visits = tree.rollout[-1]["next", "_visits"]
    reward_sum = reward_sum.squeeze(-1)
    priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
    priority[visits == 0] = float("inf")
    return priority


def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps, root_visits):
    done = False
    td_trees_visited = []

    while not done:
        if tree.subtree is None:
            td_tree = tree.rollout[-1]["next"].clone()

            if (td_tree["_visits"] > 0 or tree.parent is None) and not td_tree["done"]:
                actions = env.all_actions(td_tree)
                subtrees = []

                for action in actions:
                    td = env.step(env.reset(td_tree).update(action)).update(
                        TensorDict(
                            {
                                ("next", "_visits"): 0,
                                ("next", "_reward_sum"): env.reward_spec.zeros(),
                            }
                        )
                    )
                    new_node = torchrl.data.Tree(
                        rollout=td.unsqueeze(0),
                        node_data=td["next"].select(*forest.node_map.in_keys),
                    )
                    subtrees.append(new_node)

                # NOTE: This whole script runs about 2x faster with a lazy stack
                # than with an eager stack.
                tree.subtree = TensorDict.lazy_stack(subtrees)
                chosen_idx = torch.randint(0, len(subtrees), ()).item()
                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

            else:
                rollout_state = td_tree

            if rollout_state["done"]:
                rollout_reward = rollout_state["reward"]
            else:
                rollout = env.rollout(
                    max_steps=max_rollout_steps,
                    tensordict=rollout_state,
                )
                rollout_reward = rollout[-1]["next", "reward"]
            done = True

        else:
            priorities = traversal_priority_UCB1(tree, root_visits)
            chosen_idx = torch.argmax(priorities).item()
            tree = tree.subtree[chosen_idx]
            td_trees_visited.append(tree.rollout[-1]["next"])

    for td in td_trees_visited:
        td["_visits"] += 1
        td["_reward_sum"] += rollout_reward


def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps):
    """Performs Monte-Carlo tree search in an environment.

    Args:
        forest (MCTSForest): Forest of the tree to update. If the tree does not
            exist yet, it is added.
        root (TensorDict): The root step of the tree to update.
        env (EnvBase): Environment to perform actions in.
        num_steps (int): Number of iterations to traverse.
        max_rollout_steps (int): Maximum number of steps for each rollout.
    """
    if root not in forest:
        for action in env.all_actions(root):
            td = env.step(env.reset(root.clone()).update(action)).update(
                TensorDict(
                    {
                        ("next", "_visits"): 0,
                        ("next", "_reward_sum"): env.reward_spec.zeros(),
                    }
                )
            )
            forest.extend(td.unsqueeze(0))

    tree = forest.get_tree(root)

    # TODO: Add this to the root node
    root_visits = torch.tensor([0])

    for _ in range(num_steps):
        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps, root_visits)
        root_visits += 1

    return tree


def tree_format_fn(tree):
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        td["_reward_sum"].item(),
        td["_visits"].item(),
    ]


def get_best_move(fen, mcts_steps, rollout_steps):
    root = env.reset(TensorDict({"fen": fen}))
    tree = traverse_MCTS(forest, root, env, mcts_steps, rollout_steps)

    # print('------------------------------')
    # print(tree.to_string(tree_format_fn))
    # print('------------------------------')

    moves = []

    for subtree in tree.subtree:
        san = subtree.rollout[0]["next", "san"]
        reward_sum = subtree.rollout[-1]["next", "_reward_sum"]
        visits = subtree.rollout[-1]["next", "_visits"]
        value_avg = (reward_sum / visits).item()
        if not subtree.rollout[0]["turn"]:
            value_avg = -value_avg
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    print("------------------")
    for value_avg, san in moves:
        print(f" {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


# White has M1, best move Rd8#. Any other move loses to M2 or M1.
fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
assert get_best_move(fen0, 100, 10) == "Rd8#"

# Black has M1, best move Qg6#. Other moves give rough equality or worse.
fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
assert get_best_move(fen1, 100, 10) == "Qg6#"

# White has M2, best move Rxg8+. Any other move loses.
fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
assert get_best_move(fen2, 1000, 10) == "Rxg8+"
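A note on the selection rule used above: traversal_priority_UCB1 scores each child as (reward_sum + C * sqrt(ln(parent_visits))) / visits, flips the sign of the accumulated reward when it is black's turn (black minimizes the white-positive reward), and assigns infinite priority to unvisited children so they are always expanded first. Below is a minimal, self-contained sketch of the same scoring on plain tensors; the statistics are made up for illustration and are not part of the example script.

import torch

C = 2.0**0.5  # exploration constant, as in the example

# Hypothetical statistics for three children; the third is unvisited.
visits = torch.tensor([4.0, 1.0, 0.0])
reward_sum = torch.tensor([2.0, -1.0, 0.0])
parent_visits = torch.tensor([5.0])

priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
priority[visits == 0] = float("inf")  # unvisited children are tried first
chosen_idx = torch.argmax(priority).item()  # picks the unvisited third child here

In the example itself these quantities are read from each child's last rollout step, under the "_visits" and "_reward_sum" entries of "next".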

test/test_env.py

Lines changed: 39 additions & 20 deletions
@@ -4157,43 +4157,60 @@ def test_env_reset_with_hash(self, stateful, include_san):
         td_check = env.reset(td.select("fen_hash"))
         assert (td_check == td).all()
 
-    @pytest.mark.parametrize("include_fen", [False, True])
-    @pytest.mark.parametrize("include_pgn", [False, True])
+    @pytest.mark.parametrize("include_fen,include_pgn", [[False, True], [True, False]])
     @pytest.mark.parametrize("stateful", [False, True])
-    @pytest.mark.parametrize("mask_actions", [False, True])
-    def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
-        if not stateful and not include_fen and not include_pgn:
-            # pytest.skip("fen or pgn must be included if not stateful")
-            return
-
+    @pytest.mark.parametrize("include_hash", [False, True])
+    @pytest.mark.parametrize("include_san", [False, True])
+    @pytest.mark.parametrize("append_transform", [False, True])
+    # @pytest.mark.parametrize("mask_actions", [False, True])
+    @pytest.mark.parametrize("mask_actions", [False])
+    def test_all_actions(self, include_fen, include_pgn, stateful, include_hash, include_san, append_transform, mask_actions):
         env = ChessEnv(
             include_fen=include_fen,
             include_pgn=include_pgn,
+            include_san=include_san,
+            include_hash=include_hash,
+            include_hash_inv=include_hash,
             stateful=stateful,
             mask_actions=mask_actions,
         )
-        td = env.reset()
 
-        if not mask_actions:
-            with pytest.raises(RuntimeError, match="Cannot generate legal actions"):
-                env.all_actions()
-            return
+        def transform_reward(td):
+            if "reward" not in td:
+                return td
+            reward = td["reward"]
+            if reward == 0.5:
+                td["reward"] = 0
+            elif reward == 1 and td["turn"]:
+                td["reward"] = -td["reward"]
+            return td
+
+
+        # ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
+        # Need to transform the reward to be:
+        # white win = 1
+        # draw = 0
+        # black win = -1
+        if append_transform:
+            env = env.append_transform(transform_reward)
+
+        check_env_specs(env)
+
+        td = env.reset()
 
         # Choose random actions from the output of `all_actions`
-        for _ in range(100):
-            if stateful:
-                all_actions = env.all_actions()
-            else:
+        for step_idx in range(100):
+            if step_idx % 5 == 0:
                 # Reset the initial state first, just to make sure
                 # `all_actions` knows how to get the board state from the input.
                 env.reset()
-                all_actions = env.all_actions(td.clone())
+            all_actions = env.all_actions(td.clone())
 
             # Choose some random actions and make sure they match exactly one of
             # the actions from `all_actions`. This part is not tested when
             # `mask_actions == False`, because `rand_action` can pick illegal
             # actions in that case.
-            if mask_actions:
+            if mask_actions and step_idx % 4 == 0:
                 # TODO: Something is wrong in `ChessEnv.rand_action` which makes
                 # it fail to work properly for stateless mode. It doesn't know
                 # how to correctly reset the board state to what is given in the
@@ -4210,7 +4227,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
 
             action_idx = torch.randint(0, all_actions.shape[0], ()).item()
             chosen_action = all_actions[action_idx]
-            td = env.step(td.update(chosen_action))["next"]
+            td_new = env.step(td.update(chosen_action).clone())
+            assert (td == td_new.exclude('next')).all()
+            td = td_new['next']
 
             if td["done"]:
                 td = env.reset()
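One behavioral detail the reworked loop now checks: env.step must leave the entries of its input tensordict untouched, so everything in the stepped result except "next" has to match the original td. A sketch of the pattern being exercised, assuming env, td and chosen_action as in the test:

td_new = env.step(td.update(chosen_action).clone())
assert (td == td_new.exclude("next")).all()  # input entries are preserved by step()
td = td_new["next"]                          # advance to the resulting state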

torchrl/data/map/tree.py

Lines changed: 5 additions & 0 deletions
@@ -1364,6 +1364,11 @@ def valid_paths(cls, tree: Tree):
     def __len__(self):
         return len(self.data_map)
 
+    def __contains__(self, root: TensorDictBase):
+        if self.node_map is None:
+            return False
+        return root.select(*self.node_map.in_keys) in self.node_map
+
     def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
         """Generates a string representation of a tree in the forest.

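The new MCTSForest.__contains__ above is what lets the MCTS example check membership with the in operator before seeding a tree. A rough usage sketch, assuming a forest whose observation, action and reward keys are already configured and a root tensordict produced by env.reset():

# `root in forest` selects the forest's node-map input keys from `root` and looks
# them up in the node map; a forest with no node map yet simply returns False.
if root not in forest:
    ...  # seed the forest with one step per legal action, as in the example
tree = forest.get_tree(root)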