Skip to content

Commit d425777

Browse files
committed
[Feature,Refactor] Chess improvements: fen, pgn, pixels, san, action mask
ghstack-source-id: f294a2bc99a17911c9b62558d530b148d3c0350f Pull Request resolved: #2702
1 parent 093a159 commit d425777

File tree

8 files changed

+29947
-118
lines changed

8 files changed

+29947
-118
lines changed

test/test_env.py

+163-22
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
from torchrl.envs.transforms.transforms import (
132132
AutoResetEnv,
133133
AutoResetTransform,
134+
Tokenizer,
134135
Transform,
135136
)
136137
from torchrl.envs.utils import (
@@ -3441,35 +3442,148 @@ def test_partial_rest(self, batched):
34413442

34423443
# fen strings for board positions generated with:
34433444
# https://lichess.org/editor
3444-
@pytest.mark.parametrize("stateful", [False, True])
34453445
@pytest.mark.skipif(not _has_chess, reason="chess not found")
34463446
class TestChessEnv:
3447-
def test_env(self, stateful):
3448-
env = ChessEnv(stateful=stateful)
3449-
check_env_specs(env)
3447+
@pytest.mark.parametrize("include_pgn", [False, True])
3448+
@pytest.mark.parametrize("include_fen", [False, True])
3449+
@pytest.mark.parametrize("stateful", [False, True])
3450+
@pytest.mark.parametrize("include_hash", [False, True])
3451+
@pytest.mark.parametrize("include_san", [False, True])
3452+
def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san):
3453+
with pytest.raises(
3454+
RuntimeError, match="At least one state representation"
3455+
) if not stateful and not include_pgn and not include_fen else contextlib.nullcontext():
3456+
env = ChessEnv(
3457+
stateful=stateful,
3458+
include_pgn=include_pgn,
3459+
include_fen=include_fen,
3460+
include_hash=include_hash,
3461+
include_san=include_san,
3462+
)
3463+
# Because we always use mask_actions=True
3464+
assert isinstance(env, TransformedEnv)
3465+
check_env_specs(env)
3466+
if include_hash:
3467+
if include_fen:
3468+
assert "fen_hash" in env.observation_spec.keys()
3469+
if include_pgn:
3470+
assert "pgn_hash" in env.observation_spec.keys()
3471+
if include_san:
3472+
assert "san_hash" in env.observation_spec.keys()
3473+
3474+
def test_pgn_bijectivity(self):
3475+
np.random.seed(0)
3476+
pgn = ChessEnv._PGN_RESTART
3477+
board = ChessEnv._pgn_to_board(pgn)
3478+
pgn_prev = pgn
3479+
for _ in range(10):
3480+
moves = list(board.legal_moves)
3481+
move = np.random.choice(moves)
3482+
board.push(move)
3483+
pgn_move = ChessEnv._board_to_pgn(board)
3484+
assert pgn_move != pgn_prev
3485+
assert pgn_move == ChessEnv._board_to_pgn(ChessEnv._pgn_to_board(pgn_move))
3486+
assert pgn_move == ChessEnv._add_move_to_pgn(pgn_prev, move)
3487+
pgn_prev = pgn_move
3488+
3489+
def test_consistency(self):
3490+
env0_stateful = ChessEnv(stateful=True, include_pgn=True, include_fen=True)
3491+
env1_stateful = ChessEnv(stateful=True, include_pgn=False, include_fen=True)
3492+
env2_stateful = ChessEnv(stateful=True, include_pgn=True, include_fen=False)
3493+
env0_stateless = ChessEnv(stateful=False, include_pgn=True, include_fen=True)
3494+
env1_stateless = ChessEnv(stateful=False, include_pgn=False, include_fen=True)
3495+
env2_stateless = ChessEnv(stateful=False, include_pgn=True, include_fen=False)
3496+
torch.manual_seed(0)
3497+
r1_stateless = env1_stateless.rollout(50, break_when_any_done=False)
3498+
torch.manual_seed(0)
3499+
r1_stateful = env1_stateful.rollout(50, break_when_any_done=False)
3500+
torch.manual_seed(0)
3501+
r2_stateless = env2_stateless.rollout(50, break_when_any_done=False)
3502+
torch.manual_seed(0)
3503+
r2_stateful = env2_stateful.rollout(50, break_when_any_done=False)
3504+
torch.manual_seed(0)
3505+
r0_stateless = env0_stateless.rollout(50, break_when_any_done=False)
3506+
torch.manual_seed(0)
3507+
r0_stateful = env0_stateful.rollout(50, break_when_any_done=False)
3508+
assert (r0_stateless["action"] == r1_stateless["action"]).all()
3509+
assert (r0_stateless["action"] == r2_stateless["action"]).all()
3510+
assert (r0_stateless["action"] == r0_stateful["action"]).all()
3511+
assert (r1_stateless["action"] == r1_stateful["action"]).all()
3512+
assert (r2_stateless["action"] == r2_stateful["action"]).all()
3513+
3514+
@pytest.mark.parametrize(
3515+
"include_fen,include_pgn", [[True, False], [False, True], [True, True]]
3516+
)
3517+
@pytest.mark.parametrize("stateful", [False, True])
3518+
def test_san(self, stateful, include_fen, include_pgn):
3519+
torch.manual_seed(0)
3520+
env = ChessEnv(
3521+
stateful=stateful,
3522+
include_pgn=include_pgn,
3523+
include_fen=include_fen,
3524+
include_san=True,
3525+
)
3526+
r = env.rollout(100, break_when_any_done=False)
3527+
sans = r["next", "san"]
3528+
actions = [env.san_moves.index(san) for san in sans]
3529+
i = 0
3530+
3531+
def policy(td):
3532+
nonlocal i
3533+
td["action"] = actions[i]
3534+
i += 1
3535+
return td
34503536

3451-
def test_rollout(self, stateful):
3452-
env = ChessEnv(stateful=stateful)
3453-
env.rollout(5000)
3537+
r2 = env.rollout(100, policy=policy, break_when_any_done=False)
3538+
assert_allclose_td(r, r2)
34543539

3455-
def test_reset_white_to_move(self, stateful):
3456-
env = ChessEnv(stateful=stateful)
3540+
@pytest.mark.parametrize(
3541+
"include_fen,include_pgn", [[True, False], [False, True], [True, True]]
3542+
)
3543+
@pytest.mark.parametrize("stateful", [False, True])
3544+
def test_rollout(self, stateful, include_pgn, include_fen):
3545+
torch.manual_seed(0)
3546+
env = ChessEnv(
3547+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3548+
)
3549+
r = env.rollout(500, break_when_any_done=False)
3550+
assert r.shape == (500,)
3551+
3552+
@pytest.mark.parametrize(
3553+
"include_fen,include_pgn", [[True, False], [False, True], [True, True]]
3554+
)
3555+
@pytest.mark.parametrize("stateful", [False, True])
3556+
def test_reset_white_to_move(self, stateful, include_pgn, include_fen):
3557+
env = ChessEnv(
3558+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3559+
)
34573560
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
34583561
td = env.reset(TensorDict({"fen": fen}))
3459-
assert td["fen"] == fen
3562+
if include_fen:
3563+
assert td["fen"] == fen
3564+
assert env.board.fen() == fen
34603565
assert td["turn"] == env.lib.WHITE
34613566
assert not td["done"]
34623567

3463-
def test_reset_black_to_move(self, stateful):
3464-
env = ChessEnv(stateful=stateful)
3568+
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
3569+
@pytest.mark.parametrize("stateful", [False, True])
3570+
def test_reset_black_to_move(self, stateful, include_pgn, include_fen):
3571+
env = ChessEnv(
3572+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3573+
)
34653574
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
34663575
td = env.reset(TensorDict({"fen": fen}))
34673576
assert td["fen"] == fen
3577+
assert env.board.fen() == fen
34683578
assert td["turn"] == env.lib.BLACK
34693579
assert not td["done"]
34703580

3471-
def test_reset_done_error(self, stateful):
3472-
env = ChessEnv(stateful=stateful)
3581+
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
3582+
@pytest.mark.parametrize("stateful", [False, True])
3583+
def test_reset_done_error(self, stateful, include_pgn, include_fen):
3584+
env = ChessEnv(
3585+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3586+
)
34733587
fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
34743588
with pytest.raises(ValueError) as e_info:
34753589
env.reset(TensorDict({"fen": fen}))
@@ -3480,12 +3594,19 @@ def test_reset_done_error(self, stateful):
34803594
@pytest.mark.parametrize(
34813595
"endstate", ["white win", "black win", "stalemate", "50 move", "insufficient"]
34823596
)
3483-
def test_reward(self, stateful, reset_without_fen, endstate):
3597+
@pytest.mark.parametrize("include_pgn", [False, True])
3598+
@pytest.mark.parametrize("include_fen", [True])
3599+
@pytest.mark.parametrize("stateful", [False, True])
3600+
def test_reward(
3601+
self, stateful, reset_without_fen, endstate, include_pgn, include_fen
3602+
):
34843603
if stateful and reset_without_fen:
34853604
# reset_without_fen is only used for stateless env
34863605
return
34873606

3488-
env = ChessEnv(stateful=stateful)
3607+
env = ChessEnv(
3608+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3609+
)
34893610

34903611
if endstate == "white win":
34913612
fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
@@ -3498,28 +3619,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
34983619
fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
34993620
expected_turn = env.lib.BLACK
35003621
move = "Rg1#"
3501-
expected_reward = -1
3622+
expected_reward = 1
35023623
expected_done = True
35033624

35043625
elif endstate == "stalemate":
35053626
fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
35063627
expected_turn = env.lib.BLACK
35073628
move = "Rb7"
3508-
expected_reward = 0
3629+
expected_reward = 0.5
35093630
expected_done = True
35103631

35113632
elif endstate == "insufficient":
35123633
fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
35133634
expected_turn = env.lib.WHITE
35143635
move = "Kxd4"
3515-
expected_reward = 0
3636+
expected_reward = 0.5
35163637
expected_done = True
35173638

35183639
elif endstate == "50 move":
35193640
fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
35203641
expected_turn = env.lib.BLACK
35213642
move = "Kf7"
3522-
expected_reward = 0
3643+
expected_reward = 0.5
35233644
expected_done = True
35243645

35253646
elif endstate == "not_done":
@@ -3538,13 +3659,33 @@ def test_reward(self, stateful, reset_without_fen, endstate):
35383659
td = env.reset(TensorDict({"fen": fen}))
35393660
assert td["turn"] == expected_turn
35403661

3541-
moves = env.get_legal_moves(None if stateful else td)
3542-
td["action"] = moves.index(move)
3662+
td["action"] = env._san_moves.index(move)
35433663
td = env.step(td)["next"]
35443664
assert td["done"] == expected_done
35453665
assert td["reward"] == expected_reward
35463666
assert td["turn"] == (not expected_turn)
35473667

3668+
def test_chess_tokenized(self):
3669+
env = ChessEnv(include_fen=True, stateful=True, include_san=True)
3670+
assert isinstance(env.observation_spec["fen"], NonTensor)
3671+
env = env.append_transform(
3672+
Tokenizer(in_keys=["fen"], out_keys=["fen_tokenized"])
3673+
)
3674+
assert isinstance(env.observation_spec["fen"], NonTensor)
3675+
env.transform.transform_output_spec(env.base_env.output_spec)
3676+
env.transform.transform_input_spec(env.base_env.input_spec)
3677+
r = env.rollout(10, return_contiguous=False)
3678+
assert "fen_tokenized" in r
3679+
assert "fen" in r
3680+
assert "fen_tokenized" in r["next"]
3681+
assert "fen" in r["next"]
3682+
ftd = env.fake_tensordict()
3683+
assert "fen_tokenized" in ftd
3684+
assert "fen" in ftd
3685+
assert "fen_tokenized" in ftd["next"]
3686+
assert "fen" in ftd["next"]
3687+
env.check_env_specs()
3688+
35483689

35493690
class TestCustomEnvs:
35503691
def test_tictactoe_env(self):

torchrl/data/tensor_specs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5042,7 +5042,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
50425042

50435043
def __eq__(self, other):
50445044
return (
5045-
type(self) is type(other)
5045+
type(self) == type(other)
50465046
and self.shape == other.shape
50475047
and self._device == other._device
50485048
and set(self._specs.keys()) == set(other._specs.keys())

torchrl/envs/batched_envs.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -718,17 +718,13 @@ def _create_td(self) -> None:
718718
env_output_keys = set()
719719
env_obs_keys = set()
720720
for meta_data in self.meta_data:
721-
env_obs_keys = env_obs_keys.union(
722-
key
723-
for key in meta_data.specs["output_spec"][
724-
"full_observation_spec"
725-
].keys(True, True)
726-
)
727-
env_output_keys = env_output_keys.union(
728-
meta_data.specs["output_spec"]["full_observation_spec"].keys(
729-
True, True
730-
)
721+
keys = meta_data.specs["output_spec"]["full_observation_spec"].keys(
722+
True, True
731723
)
724+
keys = list(keys)
725+
env_obs_keys = env_obs_keys.union(keys)
726+
727+
env_output_keys = env_output_keys.union(keys)
732728
env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys)
733729
self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
734730
self._env_input_keys = sorted(env_input_keys, key=_sort_keys)
@@ -1003,7 +999,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
1003999
for i, _env in enumerate(self._envs):
10041000
if not needs_resetting[i]:
10051001
if out_tds is not None and tensordict is not None:
1006-
out_tds[i] = tensordict[i].exclude(*self._envs[i].reset_keys)
1002+
ftd = _env.observation_spec.zero()
1003+
if self.device is None:
1004+
ftd.clear_device_()
1005+
else:
1006+
ftd = ftd.to(self.device)
1007+
out_tds[i] = ftd
10071008
continue
10081009
if tensordict is not None:
10091010
tensordict_ = tensordict[i]

torchrl/envs/common.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -2505,11 +2505,26 @@ def reset(
25052505
Returns:
25062506
a tensordict (or the input tensordict, if any), modified in place with the resulting observations.
25072507
2508+
.. note:: `reset` should not be overwritten by :class:`~torchrl.envs.EnvBase` subclasses. The method to
2509+
modify is :meth:`~torchrl.envs.EnvBase._reset`.
2510+
25082511
"""
25092512
if tensordict is not None:
25102513
self._assert_tensordict_shape(tensordict)
25112514

2512-
tensordict_reset = self._reset(tensordict, **kwargs)
2515+
select_reset_only = kwargs.pop("select_reset_only", False)
2516+
if select_reset_only and tensordict is not None:
2517+
# When making rollouts with step_and_maybe_reset, it can happen that a tensordict has
2518+
# keys that are used by reset to optionally set the reset state (eg, the fen in chess). If that's the
2519+
# case and we don't throw them away here, reset will just be a no-op (put the env in the state reached
2520+
# during the previous step).
2521+
# Therefore, maybe_reset tells reset to temporarily hide the non-reset keys.
2522+
# To make step_and_maybe_reset handle custom reset states, some version of TensorDictPrimer should be used.
2523+
tensordict_reset = self._reset(
2524+
tensordict.select(*self.reset_keys, strict=False), **kwargs
2525+
)
2526+
else:
2527+
tensordict_reset = self._reset(tensordict, **kwargs)
25132528
# We assume that this is done properly
25142529
# if reset.device != self.device:
25152530
# reset = reset.to(self.device, non_blocking=True)
@@ -3293,7 +3308,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
32933308
else:
32943309
any_done = False
32953310
if any_done:
3296-
tensordict._set_str(
3311+
tensordict = tensordict._set_str(
32973312
"_reset",
32983313
done.clone(),
32993314
validated=True,
@@ -3307,7 +3322,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
33073322
key="_reset",
33083323
)
33093324
if any_done:
3310-
tensordict = self.reset(tensordict)
3325+
return self.reset(tensordict, select_reset_only=True)
33113326
return tensordict
33123327

33133328
def empty_cache(self):

0 commit comments

Comments
 (0)