4
4
# LICENSE file in the root directory of this source tree.
5
5
from __future__ import annotations
6
6
7
+ import importlib .util
8
+ import io
7
9
from typing import Dict , Optional
8
10
9
11
import torch
12
+ from PIL import Image
10
13
from tensordict import TensorDict , TensorDictBase
11
14
from torchrl .data import Categorical , Composite , NonTensor , Unbounded
12
15
13
16
from torchrl .envs import EnvBase
17
+ from torchrl .envs .common import _EnvPostInit
14
18
15
19
from torchrl .envs .utils import _classproperty
16
20
17
21
18
- class ChessEnv (EnvBase ):
22
+ class _HashMeta (_EnvPostInit ):
23
+ def __call__ (cls , * args , ** kwargs ):
24
+ instance = super ().__call__ (* args , ** kwargs )
25
+ if kwargs .get ("include_hash" ):
26
+ from torchrl .envs import Hash
27
+
28
+ in_keys = []
29
+ out_keys = []
30
+ if instance .include_san :
31
+ in_keys .append ("san" )
32
+ out_keys .append ("san_hash" )
33
+ if instance .include_fen :
34
+ in_keys .append ("fen" )
35
+ out_keys .append ("fen_hash" )
36
+ if instance .include_pgn :
37
+ in_keys .append ("pgn" )
38
+ out_keys .append ("pgn_hash" )
39
+ return instance .append_transform (Hash (in_keys , out_keys ))
40
+ return instance
41
+
42
+
43
+ class ChessEnv (EnvBase , metaclass = _HashMeta ):
19
44
"""A chess environment that follows the TorchRL API.
20
45
21
46
Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__.
22
47
23
48
Args:
24
49
stateful (bool): Whether to keep track of the internal state of the board.
25
50
If False, the state will be stored in the observation and passed back
26
- to the environment on each call. Default: ``False ``.
51
+ to the environment on each call. Default: ``True ``.
27
52
28
53
.. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape.
29
54
Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves,
@@ -90,28 +115,76 @@ class ChessEnv(EnvBase):
90
115
"""
91
116
92
117
_hash_table : Dict [int , str ] = {}
118
+ _PNG_RESTART = """[Event "?"]
119
+ [Site "?"]
120
+ [Date "????.??.??"]
121
+ [Round "?"]
122
+ [White "?"]
123
+ [Black "?"]
124
+ [Result "*"]
125
+
126
+ *"""
93
127
94
128
@_classproperty
95
129
def lib (cls ):
96
130
try :
97
131
import chess
132
+ import chess .pgn
98
133
except ImportError :
99
134
raise ImportError (
100
135
"The `chess` library could not be found. Make sure you installed it through `pip install chess`."
101
136
)
102
137
return chess
103
138
104
- def __init__ (self , stateful : bool = False ):
139
+ def __init__ (
140
+ self ,
141
+ * ,
142
+ stateful : bool = True ,
143
+ include_san : bool = False ,
144
+ include_fen : bool = False ,
145
+ include_pgn : bool = False ,
146
+ include_hash : bool = False ,
147
+ pixels : bool = False ,
148
+ ):
105
149
chess = self .lib
106
150
super ().__init__ ()
107
151
self .full_observation_spec = Composite (
108
- hashing = Unbounded (shape = (), dtype = torch .int64 ),
109
- fen = NonTensor (shape = ()),
110
152
turn = Categorical (n = 2 , dtype = torch .bool , shape = ()),
111
153
)
154
+ self .include_san = include_san
155
+ self .include_fen = include_fen
156
+ self .include_pgn = include_pgn
157
+ if include_san :
158
+ self .full_observation_spec ["san" ] = NonTensor (shape = (), example_data = "Nc6" )
159
+ if include_pgn :
160
+ self .full_observation_spec ["pgn" ] = NonTensor (
161
+ shape = (), example_data = self ._PNG_RESTART
162
+ )
163
+ if include_fen :
164
+ self .full_observation_spec ["fen" ] = NonTensor (shape = (), example_data = "any" )
165
+ if not stateful and not (include_pgn or include_fen ):
166
+ raise RuntimeError (
167
+ "At least one state representation (pgn or fen) must be enabled when stateful "
168
+ f"is { stateful } ."
169
+ )
170
+
112
171
self .stateful = stateful
172
+
113
173
if not self .stateful :
114
174
self .full_state_spec = self .full_observation_spec .clone ()
175
+
176
+ self .pixels = pixels
177
+ if pixels :
178
+ if importlib .util .find_spec ("cairosvg" ) is None :
179
+ raise ImportError (
180
+ "Please install cairosvg to use this environment with pixel rendering."
181
+ )
182
+ if importlib .util .find_spec ("torchvision" ) is None :
183
+ raise ImportError (
184
+ "Please install torchvision to use this environment with pixel rendering."
185
+ )
186
+ self .full_observation_spec ["pixels" ] = Unbounded (shape = ())
187
+
115
188
self .full_action_spec = Composite (
116
189
action = Categorical (n = - 1 , shape = (), dtype = torch .int64 )
117
190
)
@@ -132,41 +205,126 @@ def _is_done(self, board):
132
205
133
206
def _reset (self , tensordict = None ):
134
207
fen = None
208
+ pgn = None
135
209
if tensordict is not None :
136
- fen = self ._get_fen (tensordict ).data
137
- dest = tensordict .empty ()
210
+ if self .include_fen :
211
+ fen = self ._get_fen (tensordict ).data
212
+ dest = tensordict .empty ()
213
+ if self .include_pgn :
214
+ fen = self ._get_pgn (tensordict ).data
215
+ dest = tensordict .empty ()
138
216
else :
139
217
dest = TensorDict ()
140
218
141
- if fen is None :
219
+ if fen is None and pgn is None :
142
220
self .board .reset ()
143
- fen = self .board .fen ()
221
+ if self .include_fen and fen is None :
222
+ fen = self .board .fen ()
223
+ if self .include_pgn and pgn is None :
224
+ pgn = self ._PNG_RESTART
144
225
else :
145
- self .board .set_fen (fen )
146
- if self ._is_done (self .board ):
147
- raise ValueError (
148
- "Cannot reset to a fen that is a gameover state." f" fen: { fen } "
149
- )
150
-
151
- hashing = hash (fen )
226
+ if fen is not None :
227
+ self .board .set_fen (fen )
228
+ if self ._is_done (self .board ):
229
+ raise ValueError (
230
+ "Cannot reset to a fen that is a gameover state." f" fen: { fen } "
231
+ )
232
+ elif pgn is not None :
233
+ self .board = self ._pgn_to_board (pgn )
152
234
153
235
self ._set_action_space ()
154
236
turn = self .board .turn
155
- return dest .set ("fen" , fen ).set ("hashing" , hashing ).set ("turn" , turn )
237
+ if self .include_san :
238
+ dest .set ("san" , "[SAN][START]" )
239
+ if self .include_fen :
240
+ if fen is None :
241
+ fen = self .board .fen ()
242
+ dest .set ("fen" , fen )
243
+ if self .include_pgn :
244
+ if pgn is None :
245
+ pgn = self ._board_to_pgn (self .board )
246
+ dest .set ("pgn" , pgn )
247
+ dest .set ("turn" , turn )
248
+ if self .pixels :
249
+ dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
250
+ return dest
251
+
252
+ _cairosvg_lib = None
253
+
254
+ @_classproperty
255
+ def _cairosvg (cls ):
256
+ csvg = cls ._cairosvg_lib
257
+ if csvg is None :
258
+ import cairosvg
259
+
260
+ csvg = cls ._cairosvg_lib = cairosvg
261
+ return csvg
262
+
263
+ _torchvision_lib = None
264
+
265
+ @_classproperty
266
+ def _torchvision (cls ):
267
+ tv = cls ._torchvision_lib
268
+ if tv is None :
269
+ import torchvision
270
+
271
+ tv = cls ._torchvision_lib = torchvision
272
+ return tv
273
+
274
+ @classmethod
275
+ def _get_tensor_image (cls , board ):
276
+ try :
277
+ svg = board ._repr_svg_ ()
278
+ # Convert SVG to PNG using cairosvg
279
+ png_data = io .BytesIO ()
280
+ cls ._cairosvg .svg2png (bytestring = svg .encode ("utf-8" ), write_to = png_data )
281
+ png_data .seek (0 )
282
+ # Open the PNG image using Pillow
283
+ img = Image .open (png_data )
284
+ img = cls ._torchvision .transforms .functional .pil_to_tensor (img )
285
+ except ImportError :
286
+ raise ImportError (
287
+ "Chess rendering requires cairosvg and torchvision to be installed."
288
+ )
289
+ return img
156
290
157
291
def _set_action_space (self , tensordict : TensorDict | None = None ):
158
292
if not self .stateful and tensordict is not None :
159
293
fen = self ._get_fen (tensordict ).data
160
294
self .board .set_fen (fen )
161
295
self .action_spec .set_provisional_n (self .board .legal_moves .count ())
162
296
297
+ @classmethod
298
+ def _pgn_to_board (
299
+ cls , pgn_string : str , board : "chess.Board" | None = None
300
+ ) -> "chess.Board" :
301
+ pgn_io = io .StringIO (pgn_string )
302
+ game = cls .lib .pgn .read_game (pgn_io )
303
+ if board is None :
304
+ board = cls .Board ()
305
+ else :
306
+ board .reset ()
307
+ for move in game .mainline_moves ():
308
+ board .push (move )
309
+ return board
310
+
311
+ @classmethod
312
+ def _board_to_pgn (cls , board : "chess.Board" ) -> str :
313
+ # Create a new Game object
314
+ game = cls .lib .pgn .Game ()
315
+
316
+ # Add the moves to the game
317
+ node = game
318
+ for move in board .move_stack :
319
+ node = node .add_variation (move )
320
+
321
+ # Generate the PGN string
322
+ pgn_string = str (game )
323
+ return pgn_string
324
+
163
325
@classmethod
164
326
def _get_fen (cls , tensordict ):
165
327
fen = tensordict .get ("fen" , None )
166
- if fen is None :
167
- hashing = tensordict .get ("hashing" , None )
168
- if hashing is not None :
169
- fen = cls ._hash_table .get (hashing .item ())
170
328
return fen
171
329
172
330
def get_legal_moves (self , tensordict = None , uci = False ):
@@ -205,19 +363,40 @@ def _step(self, tensordict):
205
363
# action
206
364
action = tensordict .get ("action" )
207
365
board = self .board
366
+
208
367
if not self .stateful :
209
- fen = self ._get_fen (tensordict ).data
210
- board .set_fen (fen )
368
+ if self .include_fen :
369
+ fen = self ._get_fen (tensordict ).data
370
+ board .set_fen (fen )
371
+ elif self .include_pgn :
372
+ pgn = self ._get_pgn (tensordict ).data
373
+ self ._pgn_to_board (pgn , board )
374
+ else :
375
+ raise RuntimeError (
376
+ "Not enough information to deduce the board. If stateful=False, include_pgn or include_fen must be True."
377
+ )
378
+
211
379
action = list (board .legal_moves )[action ]
380
+ san = None
381
+ if self .include_san :
382
+ san = board .san (action )
212
383
board .push (action )
384
+
213
385
self ._set_action_space ()
214
386
215
- # Collect data
216
- fen = self .board .fen ()
217
387
dest = tensordict .empty ()
218
- hashing = hash (fen )
219
- dest .set ("fen" , fen )
220
- dest .set ("hashing" , hashing )
388
+
389
+ # Collect data
390
+ if self .include_fen :
391
+ fen = board .fen ()
392
+ dest .set ("fen" , fen )
393
+
394
+ if self .include_pgn :
395
+ pgn = self ._board_to_pgn (board )
396
+ dest .set ("pgn" , pgn )
397
+
398
+ if san is not None :
399
+ dest .set ("san" , san )
221
400
222
401
turn = torch .tensor (board .turn )
223
402
if board .is_checkmate ():
@@ -226,12 +405,15 @@ def _step(self, tensordict):
226
405
reward_val = 1 if winner == self .lib .WHITE else - 1
227
406
else :
228
407
reward_val = 0
408
+
229
409
reward = torch .tensor ([reward_val ], dtype = torch .int32 )
230
410
done = self ._is_done (board )
231
411
dest .set ("reward" , reward )
232
412
dest .set ("turn" , turn )
233
413
dest .set ("done" , [done ])
234
414
dest .set ("terminated" , [done ])
415
+ if self .pixels :
416
+ dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
235
417
return dest
236
418
237
419
def _set_seed (self , * args , ** kwargs ):
0 commit comments