import tensordict.nn
import torch
import tqdm
from tensordict.nn import (
    ProbabilisticTensorDictModule as TDProb,
    ProbabilisticTensorDictSequential as TDProbSeq,
    TensorDictModule as TDMod,
    TensorDictSequential as TDSeq,
)
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement

from torchrl.envs import ChessEnv, Tokenizer
from torchrl.modules import MLP
from torchrl.modules.distributions import MaskedCategorical
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

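# Keep per-key log-probabilities rather than aggregating them into a single "sample_log_prob" entry.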
tensordict.nn.set_composite_lp_aggregate(False)

# ...

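# Embedding for the legal-move indices; n (the move-vocabulary size) is defined earlier, outside this excerpt.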
embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64)

# Embedding for the tokenized FEN string
embedding_fen = nn.Embedding(
    num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64
)

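# Shared MLP trunk applied to the concatenated move and FEN embeddings.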
backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU)

# ...

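# Scalar value head for the critic.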
critic_head = nn.Linear(512, 1)
critic_head.bias.data.fill_(0)

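# Action head: sample a move from a categorical distribution over the logits, restricted to the legal-move mask.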
prob = TDProb(
    in_keys=["logits", "mask"],
    out_keys=["action"],
    distribution_class=MaskedCategorical,
    return_log_prob=True,
)


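# Build a boolean mask over the move vocabulary: True at each legal-move index, with the extra trailing (padding) index dropped.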
def make_mask(idx):
    mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool)
    return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1]


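# Policy: mask the legal moves, embed the moves and the tokenized FEN, concatenate, run the trunk, project to logits and sample through the masked-categorical head.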
actor = TDProbSeq(
    TDMod(make_mask, in_keys=["legal_moves"], out_keys=["mask"]),
    TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]),
    TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]),
    TDMod(
        lambda *args: torch.cat(
            [arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1
        ),
        in_keys=["embedded_legal_moves", "embedded_fen"],
        out_keys=["features"],
    ),
    TDMod(backbone, in_keys=["features"], out_keys=["hidden"]),
    TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]),
    prob,
)

# ...

optim = Adam(loss.parameters())

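# Advantage estimation; the value network reuses the actor's feature pipeline (all modules except the logits head and the sampling step) followed by the critic.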
gae = GAE(
    value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True
)

# Create a data collector
collector = SyncDataCollector(
    # ...
    total_frames=1_000_000,
)

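# One replay buffer per player, each holding half of the frames collected per batch, sampled without replacement.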
replay_buffer0 = ReplayBuffer(
    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
    batch_size=batch_size,
    sampler=SamplerWithoutReplacement(),
)
replay_buffer1 = ReplayBuffer(
    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
    batch_size=batch_size,
    sampler=SamplerWithoutReplacement(),
)

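# Training loop: even-indexed steps belong to player 0, odd-indexed steps to player 1.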
for data in tqdm.tqdm(collector):
    data = data.filter_non_tensor_data()
    print(" data", data[0::2])
    for i in range(num_epochs):
        replay_buffer0.empty()
        replay_buffer1.empty()
        # ...
        # player 1
        data1 = gae(data[1::2])
        if i == 0:
            # Rough win-rate proxy: total reward per finished game, for each player.
            print(
                "win rate for 0",
                data0["next", "reward"].sum()
                / data["next", "done"].sum().clamp_min(1e-6),
            )
            print(
                "win rate for 1",
                data1["next", "reward"].sum()
                / data["next", "done"].sum().clamp_min(1e-6),
            )

        replay_buffer0.extend(data0)
        replay_buffer1.extend(data1)

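        # PPO update: iterate over both buffers in lockstep and average the two players' losses.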
        n_iter = collector.frames_per_batch // (2 * batch_size)
        for (d0, d1) in tqdm.tqdm(
            zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter
        ):
            loss_vals = (loss(d0) + loss(d1)) / 2
            loss_vals.sum(reduce=True).backward()
            gn = clip_grad_norm_(loss.parameters(), 100.0)