
Commit 29e955a

[Example] Self-play chess PPO example
ghstack-source-id: 4b92c30998cac7fbd578a202e6b80bf56d482fa7
Pull Request resolved: #2709
1 parent 3f7993d commit 29e955a

File tree

1 file changed: +118 -0 lines changed


examples/agents/ppo-chess.py

@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import tensordict.nn
import torch
import tqdm
from tensordict.nn import TensorDictSequential as TDSeq, TensorDictModule as TDMod, \
    ProbabilisticTensorDictModule as TDProb, ProbabilisticTensorDictSequential as TDProbSeq
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam

from torchrl.collectors import SyncDataCollector

from torchrl.envs import ChessEnv, Tokenizer
from torchrl.modules import MLP
from torchrl.modules.distributions import MaskedCategorical
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement

# Keep the log-probs of composite distributions as separate entries rather than
# aggregating them into a single one.
tensordict.nn.set_composite_lp_aggregate(False)

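# Training configuration: PPO epochs per collected batch, minibatch size, and
# number of frames collected per iteration.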
num_epochs = 10
batch_size = 256
frames_per_batch = 2048

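# The chess env exposes the board's FEN string and the list of legal moves at
# each step; both are used below as inputs to the policy and for the action mask.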
env = ChessEnv(include_legal_moves=True, include_fen=True)

# tokenize the fen - assume max 70 elements
transform = Tokenizer(in_keys=["fen"], out_keys=["fen_tokenized"], max_length=70)

env = env.append_transform(transform)
n = env.action_spec.n
# Quick sanity check of the env with a random rollout (rollout stops at the end of the game by default).
print(env.rollout(10000))

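# Model: embed the legal-move indices and the tokenized FEN, flatten and
# concatenate them, and feed a shared MLP backbone with separate actor (logits)
# and critic (state-value) heads.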
# Embedding layer for the legal moves
embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64)

# Embedding for the fen
embedding_fen = nn.Embedding(num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64)

backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU)

actor_head = nn.Linear(512, env.action_spec.n)
actor_head.bias.data.fill_(0)

critic_head = nn.Linear(512, 1)
critic_head.bias.data.fill_(0)

# Sample actions from a categorical distribution restricted to the legal moves
# (return_log_prob=True also writes the sample's log-prob).
prob = TDProb(in_keys=["logits", "mask"], out_keys=["action"], distribution_class=MaskedCategorical, return_log_prob=True)

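# Build a boolean mask over the action space from the padded legal-move indices;
# the column for the padding index (assumed to be n) is dropped after scattering.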
def make_mask(idx):
    mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool)
    return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1]

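# Actor: mask -> embeddings -> concatenated features -> backbone -> logits -> masked sampling.
# The critic reuses the "hidden" features produced by the actor's backbone.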
actor = TDProbSeq(
    TDMod(
        make_mask,
        in_keys=["legal_moves"], out_keys=["mask"]),
    TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]),
    TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]),
    TDMod(lambda *args: torch.cat([arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1), in_keys=["embedded_legal_moves", "embedded_fen"],
          out_keys=["features"]),
    TDMod(backbone, in_keys=["features"], out_keys=["hidden"]),
    TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]),
    prob,
)
critic = TDSeq(
    TDMod(critic_head, in_keys=["hidden"], out_keys=["state_value"]),
)

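# Smoke-test the actor on a short rollout, then set up the PPO loss, optimizer
# and GAE. The GAE value network reuses the actor's feature pipeline (all modules
# except the logits head and the distribution) followed by the critic head.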
print(env.rollout(3, actor))
# loss
loss = ClipPPOLoss(actor, critic)

optim = Adam(loss.parameters())

gae = GAE(value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True)

# Create a data collector
collector = SyncDataCollector(
    create_env_fn=env,
    policy=actor,
    frames_per_batch=frames_per_batch,
    total_frames=1_000_000,
)

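# One replay buffer per player; transitions are split by turn parity below
# (data[0::2] for player 0, data[1::2] for player 1).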
replay_buffer0 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
replay_buffer1 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())

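# Training loop: for each collected batch, compute GAE advantages separately for
# the two players, fill their buffers, then run PPO epochs on paired minibatches.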
for data in tqdm.tqdm(collector):
    data = data.filter_non_tensor_data()
    print('data', data[0::2])
    for i in range(num_epochs):
        replay_buffer0.empty()
        replay_buffer1.empty()
        with torch.no_grad():
            # player 0
            data0 = gae(data[0::2])
            # player 1
            data1 = gae(data[1::2])
            if i == 0:
                print('win rate for 0', data0["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
                print('win rate for 1', data1["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))

            replay_buffer0.extend(data0)
            replay_buffer1.extend(data1)

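        # One optimization pass over paired minibatches from the two buffers,
        # averaging the two players' PPO losses.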
        n_iter = collector.frames_per_batch//(2 * batch_size)
        for (d0, d1) in tqdm.tqdm(zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter):
            loss_vals = (loss(d0) + loss(d1)) / 2
            # Sum the individual loss terms (objective, critic, entropy) into a scalar.
            loss_vals.sum(reduce=True).backward()
            gn = clip_grad_norm_(loss.parameters(), 100.0)
            optim.step()
            optim.zero_grad()
