Skip to content

Commit 479e66e

Browse files
committed
[Feature,Refactor] Chess improvements: fen, pgn, pixels, san
ghstack-source-id: 87c458ebe69a21569719d12ad17fb3d7a356da0d Pull Request resolved: #2702
1 parent 3343b29 commit 479e66e

File tree

1 file changed

+210
-28
lines changed

1 file changed

+210
-28
lines changed

torchrl/envs/custom/chess.py

+210-28
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,51 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7+
import importlib.util
8+
import io
79
from typing import Dict, Optional
810

911
import torch
12+
from PIL import Image
1013
from tensordict import TensorDict, TensorDictBase
1114
from torchrl.data import Categorical, Composite, NonTensor, Unbounded
1215

1316
from torchrl.envs import EnvBase
17+
from torchrl.envs.common import _EnvPostInit
1418

1519
from torchrl.envs.utils import _classproperty
1620

1721

18-
class ChessEnv(EnvBase):
22+
class _HashMeta(_EnvPostInit):
    """Metaclass that optionally appends a ``Hash`` transform to a new env.

    When the environment is built with the keyword argument
    ``include_hash=True``, a ``Hash`` transform is appended whose in-keys are
    the enabled string observations (``"san"``, ``"fen"``, ``"pgn"``) and whose
    out-keys are the same names suffixed with ``"_hash"``.
    """

    def __call__(cls, *args, **kwargs):
        env = super().__call__(*args, **kwargs)
        # Only the keyword form is inspected; without it the env is returned as built.
        if not kwargs.get("include_hash"):
            return env

        from torchrl.envs import Hash

        # Keep the san / fen / pgn ordering of the generated key lists.
        enabled = [
            key
            for key, flag in (
                ("san", env.include_san),
                ("fen", env.include_fen),
                ("pgn", env.include_pgn),
            )
            if flag
        ]
        hashed = [f"{key}_hash" for key in enabled]
        # The result of ``append_transform`` is what the constructor call returns.
        return env.append_transform(Hash(enabled, hashed))
41+
42+
43+
class ChessEnv(EnvBase, metaclass=_HashMeta):
1944
"""A chess environment that follows the TorchRL API.
2045
2146
Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__.
2247
2348
Args:
2449
stateful (bool): Whether to keep track of the internal state of the board.
2550
If False, the state will be stored in the observation and passed back
26-
to the environment on each call. Default: ``False``.
51+
to the environment on each call. Default: ``True``.
2752
2853
.. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape.
2954
Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves,
@@ -90,28 +115,76 @@ class ChessEnv(EnvBase):
90115
"""
91116

92117
_hash_table: Dict[int, str] = {}
118+
_PNG_RESTART = """[Event "?"]
119+
[Site "?"]
120+
[Date "????.??.??"]
121+
[Round "?"]
122+
[White "?"]
123+
[Black "?"]
124+
[Result "*"]
125+
126+
*"""
93127

94128
@_classproperty
def lib(cls):
    """Return the ``chess`` module, importing it on access.

    ``chess.pgn`` is imported as well so that the ``pgn`` submodule is
    reachable as an attribute of the returned module.

    Raises:
        ImportError: if the ``chess`` package (or its ``pgn`` submodule)
            cannot be imported. The original import failure is attached as
            the cause.
    """
    try:
        import chess
        import chess.pgn
    except ImportError as err:
        raise ImportError(
            "The `chess` library could not be found. Make sure you installed it through `pip install chess`."
        ) from err
    return chess
103138

104-
def __init__(self, stateful: bool = False):
139+
def __init__(
140+
self,
141+
*,
142+
stateful: bool = True,
143+
include_san: bool = False,
144+
include_fen: bool = False,
145+
include_pgn: bool = False,
146+
include_hash: bool = False,
147+
pixels: bool = False,
148+
):
105149
chess = self.lib
106150
super().__init__()
107151
self.full_observation_spec = Composite(
108-
hashing=Unbounded(shape=(), dtype=torch.int64),
109-
fen=NonTensor(shape=()),
110152
turn=Categorical(n=2, dtype=torch.bool, shape=()),
111153
)
154+
self.include_san = include_san
155+
self.include_fen = include_fen
156+
self.include_pgn = include_pgn
157+
if include_san:
158+
self.full_observation_spec["san"] = NonTensor(shape=(), example_data="Nc6")
159+
if include_pgn:
160+
self.full_observation_spec["pgn"] = NonTensor(
161+
shape=(), example_data=self._PNG_RESTART
162+
)
163+
if include_fen:
164+
self.full_observation_spec["fen"] = NonTensor(shape=(), example_data="any")
165+
if not stateful and not (include_pgn or include_fen):
166+
raise RuntimeError(
167+
"At least one state representation (pgn or fen) must be enabled when stateful "
168+
f"is {stateful}."
169+
)
170+
112171
self.stateful = stateful
172+
113173
if not self.stateful:
114174
self.full_state_spec = self.full_observation_spec.clone()
175+
176+
self.pixels = pixels
177+
if pixels:
178+
if importlib.util.find_spec("cairosvg") is None:
179+
raise ImportError(
180+
"Please install cairosvg to use this environment with pixel rendering."
181+
)
182+
if importlib.util.find_spec("torchvision") is None:
183+
raise ImportError(
184+
"Please install torchvision to use this environment with pixel rendering."
185+
)
186+
self.full_observation_spec["pixels"] = Unbounded(shape=())
187+
115188
self.full_action_spec = Composite(
116189
action=Categorical(n=-1, shape=(), dtype=torch.int64)
117190
)
@@ -132,41 +205,126 @@ def _is_done(self, board):
132205

133206
def _reset(self, tensordict=None):
    """Reset the board and build the initial observation.

    Args:
        tensordict: optional input container. When provided, its ``"fen"``
            entry (if ``include_fen``) and ``"pgn"`` entry (if
            ``include_pgn``) are read to restore a position. A fen takes
            precedence over a pgn when both are present. Missing entries
            are treated as "not provided".

    Returns:
        A container holding ``"turn"`` plus whichever of ``"san"``,
        ``"fen"``, ``"pgn"`` and ``"pixels"`` are enabled on this env.

    Raises:
        ValueError: if the provided fen describes a finished game.
    """
    fen = None
    pgn = None
    if tensordict is None:
        dest = TensorDict()
    else:
        # Allocate the output unconditionally so it exists even when neither
        # fen nor pgn is enabled.
        dest = tensordict.empty()
        if self.include_fen:
            fen_entry = self._get_fen(tensordict)
            if fen_entry is not None:
                fen = fen_entry.data
        if self.include_pgn:
            # The pgn must land in ``pgn`` (not ``fen``) so that it is
            # replayed as a game rather than parsed as a position string.
            pgn_entry = tensordict.get("pgn", None)
            if pgn_entry is not None:
                pgn = pgn_entry.data

    if fen is not None:
        self.board.set_fen(fen)
        if self._is_done(self.board):
            raise ValueError(
                "Cannot reset to a fen that is a gameover state." f" fen: {fen}"
            )
    elif pgn is not None:
        # Replay the game into the existing board object.
        self.board = self._pgn_to_board(pgn, self.board)
    else:
        self.board.reset()
        if self.include_pgn:
            pgn = self._PNG_RESTART

    self._set_action_space()
    turn = self.board.turn
    if self.include_san:
        dest.set("san", "[SAN][START]")
    if self.include_fen:
        if fen is None:
            fen = self.board.fen()
        dest.set("fen", fen)
    if self.include_pgn:
        if pgn is None:
            pgn = self._board_to_pgn(self.board)
        dest.set("pgn", pgn)
    dest.set("turn", turn)
    if self.pixels:
        dest.set("pixels", self._get_tensor_image(board=self.board))
    return dest
251+
252+
_cairosvg_lib = None
253+
254+
@_classproperty
def _cairosvg(cls):
    """Return the ``cairosvg`` module, importing it on first access.

    The imported module is cached on the class in ``_cairosvg_lib``.
    """
    if cls._cairosvg_lib is None:
        import cairosvg

        cls._cairosvg_lib = cairosvg
    return cls._cairosvg_lib
262+
263+
_torchvision_lib = None
264+
265+
@_classproperty
def _torchvision(cls):
    """Return the ``torchvision`` module, importing it on first access.

    The imported module is cached on the class in ``_torchvision_lib``.
    """
    if cls._torchvision_lib is None:
        import torchvision

        cls._torchvision_lib = torchvision
    return cls._torchvision_lib
273+
274+
@classmethod
def _get_tensor_image(cls, board):
    """Render ``board`` to an image tensor.

    The board's SVG representation is rasterized to PNG with ``cairosvg``,
    opened with Pillow and converted with torchvision's ``pil_to_tensor``.

    Args:
        board: object exposing ``_repr_svg_()`` (a ``chess.Board``).

    Returns:
        The result of ``pil_to_tensor`` applied to the rendered PNG.

    Raises:
        ImportError: if ``cairosvg`` or ``torchvision`` cannot be imported.
    """
    # Only the lazy imports can raise the ImportError handled here, so the
    # guarded region is limited to resolving the two callables.
    try:
        svg2png = cls._cairosvg.svg2png
        pil_to_tensor = cls._torchvision.transforms.functional.pil_to_tensor
    except ImportError as err:
        raise ImportError(
            "Chess rendering requires cairosvg and torchvision to be installed."
        ) from err

    svg = board._repr_svg_()
    # Convert SVG to PNG using cairosvg, writing into an in-memory buffer.
    png_data = io.BytesIO()
    svg2png(bytestring=svg.encode("utf-8"), write_to=png_data)
    png_data.seek(0)
    # Open the PNG image using Pillow, then convert it to a tensor.
    img = Image.open(png_data)
    return pil_to_tensor(img)
156290

157291
def _set_action_space(self, tensordict: TensorDict | None = None):
    """Update the action spec's provisional cardinality from the legal moves.

    In stateless mode, when a tensordict is given, the board is first
    restored from its ``"fen"`` entry, or from its ``"pgn"`` entry when no
    fen is present.

    Args:
        tensordict: optional container carrying the board state.

    Raises:
        RuntimeError: in stateless mode, if the tensordict carries neither
            a ``"fen"`` nor a ``"pgn"`` entry.
    """
    if not self.stateful and tensordict is not None:
        fen = self._get_fen(tensordict)
        if fen is not None:
            self.board.set_fen(fen.data)
        else:
            # PGN-only stateless envs have no fen entry to restore from.
            pgn = tensordict.get("pgn", None)
            if pgn is None:
                raise RuntimeError(
                    "Not enough information to deduce the board: the tensordict "
                    "has neither a 'fen' nor a 'pgn' entry."
                )
            self._pgn_to_board(pgn.data, self.board)
    self.action_spec.set_provisional_n(self.board.legal_moves.count())
162296

297+
@classmethod
def _pgn_to_board(
    cls, pgn_string: str, board: "chess.Board | None" = None  # noqa: F821
) -> "chess.Board":  # noqa: F821
    """Rebuild a board by replaying the mainline moves of a PGN game.

    Args:
        pgn_string: PGN text of a single game.
        board: optional board to reuse. It is reset before the moves are
            replayed. When omitted, a new ``chess.Board`` is created.

    Returns:
        The board after all mainline moves have been pushed.

    Raises:
        ValueError: if no game can be read from ``pgn_string``.
    """
    game = cls.lib.pgn.read_game(io.StringIO(pgn_string))
    # ``read_game`` yields None when the input holds no game; fail with a
    # clear message before touching the caller's board.
    if game is None:
        raise ValueError(
            f"Could not read a game from the given pgn string: {pgn_string!r}"
        )
    if board is None:
        # The Board class lives on the chess module exposed by ``cls.lib``.
        board = cls.lib.Board()
    else:
        board.reset()
    for move in game.mainline_moves():
        board.push(move)
    return board
310+
311+
@classmethod
def _board_to_pgn(cls, board: "chess.Board") -> str:
    """Serialize the moves stored in ``board.move_stack`` as a PGN string.

    A new ``chess.pgn.Game`` is created and each move of the stack is
    appended as a variation of the previous node; the game's string form
    is returned.
    """
    root = cls.lib.pgn.Game()

    # Walk down the line, attaching each move to the node created before it.
    tail = root
    for played in board.move_stack:
        tail = tail.add_variation(played)

    return str(root)
324+
163325
@classmethod
def _get_fen(cls, tensordict):
    """Return the ``"fen"`` entry of ``tensordict``, or ``None`` if absent."""
    return tensordict.get("fen", None)
171329

172330
def get_legal_moves(self, tensordict=None, uci=False):
@@ -205,19 +363,40 @@ def _step(self, tensordict):
205363
# action
206364
action = tensordict.get("action")
207365
board = self.board
366+
208367
if not self.stateful:
209-
fen = self._get_fen(tensordict).data
210-
board.set_fen(fen)
368+
if self.include_fen:
369+
fen = self._get_fen(tensordict).data
370+
board.set_fen(fen)
371+
elif self.include_pgn:
372+
pgn = self._get_pgn(tensordict).data
373+
self._pgn_to_board(pgn, board)
374+
else:
375+
raise RuntimeError(
376+
"Not enough information to deduce the board. If stateful=False, include_pgn or include_fen must be True."
377+
)
378+
211379
action = list(board.legal_moves)[action]
380+
san = None
381+
if self.include_san:
382+
san = board.san(action)
212383
board.push(action)
384+
213385
self._set_action_space()
214386

215-
# Collect data
216-
fen = self.board.fen()
217387
dest = tensordict.empty()
218-
hashing = hash(fen)
219-
dest.set("fen", fen)
220-
dest.set("hashing", hashing)
388+
389+
# Collect data
390+
if self.include_fen:
391+
fen = board.fen()
392+
dest.set("fen", fen)
393+
394+
if self.include_pgn:
395+
pgn = self._board_to_pgn(board)
396+
dest.set("pgn", pgn)
397+
398+
if san is not None:
399+
dest.set("san", san)
221400

222401
turn = torch.tensor(board.turn)
223402
if board.is_checkmate():
@@ -226,12 +405,15 @@ def _step(self, tensordict):
226405
reward_val = 1 if winner == self.lib.WHITE else -1
227406
else:
228407
reward_val = 0
408+
229409
reward = torch.tensor([reward_val], dtype=torch.int32)
230410
done = self._is_done(board)
231411
dest.set("reward", reward)
232412
dest.set("turn", turn)
233413
dest.set("done", [done])
234414
dest.set("terminated", [done])
415+
if self.pixels:
416+
dest.set("pixels", self._get_tensor_image(board=self.board))
235417
return dest
236418

237419
def _set_seed(self, *args, **kwargs):

0 commit comments

Comments
 (0)