Skip to content

Commit e0a78e4

Browse files
committed
[Feature,Refactor] Chess improvements: fen, pgn, pixels, san
ghstack-source-id: c3899e8 Pull Request resolved: #2702
1 parent 70ab423 commit e0a78e4

File tree

6 files changed

+29782
-102
lines changed

6 files changed

+29782
-102
lines changed

test/test_env.py

+103-21
Original file line numberDiff line numberDiff line change
@@ -3291,6 +3291,10 @@ def test_batched_dynamic(self, break_when_any_done):
32913291
)
32923292
del env_no_buffers
32933293
gc.collect()
3294+
# print(dummy_rollouts)
3295+
# print(rollout_no_buffers_serial)
3296+
# # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)):
3297+
# assert_allclose_td(a, b)
32943298
assert_allclose_td(
32953299
dummy_rollouts.exclude("action"),
32963300
rollout_no_buffers_serial.exclude("action"),
@@ -3386,35 +3390,107 @@ def test_partial_rest(self, batched):
33863390

33873391
# FEN strings for board positions generated with:
33883392
# https://lichess.org/editor
3389-
@pytest.mark.parametrize("stateful", [False, True])
33903393
@pytest.mark.skipif(not _has_chess, reason="chess not found")
33913394
class TestChessEnv:
3392-
def test_env(self, stateful):
3393-
env = ChessEnv(stateful=stateful)
3394-
check_env_specs(env)
3395+
@pytest.mark.parametrize("include_pgn", [False, True])
3396+
@pytest.mark.parametrize("include_fen", [False, True])
3397+
@pytest.mark.parametrize("stateful", [False, True])
3398+
def test_env(self, stateful, include_pgn, include_fen):
3399+
with pytest.raises(
3400+
RuntimeError, match="At least one state representation"
3401+
) if not stateful and not include_pgn and not include_fen else contextlib.nullcontext():
3402+
env = ChessEnv(
3403+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3404+
)
3405+
check_env_specs(env)
33953406

3396-
def test_rollout(self, stateful):
3397-
env = ChessEnv(stateful=stateful)
3398-
env.rollout(5000)
3407+
def test_pgn_bijectivity(self):
3408+
np.random.seed(0)
3409+
pgn = ChessEnv._PGN_RESTART
3410+
board = ChessEnv._pgn_to_board(pgn)
3411+
pgn_prev = pgn
3412+
for _ in range(10):
3413+
moves = list(board.legal_moves)
3414+
move = np.random.choice(moves)
3415+
board.push(move)
3416+
pgn_move = ChessEnv._board_to_pgn(board)
3417+
assert pgn_move != pgn_prev
3418+
assert pgn_move == ChessEnv._board_to_pgn(ChessEnv._pgn_to_board(pgn_move))
3419+
assert pgn_move == ChessEnv._add_move_to_pgn(pgn_prev, move)
3420+
pgn_prev = pgn_move
3421+
3422+
def test_consistency(self):
3423+
env0_stateful = ChessEnv(stateful=True, include_pgn=True, include_fen=True)
3424+
env1_stateful = ChessEnv(stateful=True, include_pgn=False, include_fen=True)
3425+
env2_stateful = ChessEnv(stateful=True, include_pgn=True, include_fen=False)
3426+
env0_stateless = ChessEnv(stateful=False, include_pgn=True, include_fen=True)
3427+
env1_stateless = ChessEnv(stateful=False, include_pgn=False, include_fen=True)
3428+
env2_stateless = ChessEnv(stateful=False, include_pgn=True, include_fen=False)
3429+
torch.manual_seed(0)
3430+
r1_stateless = env1_stateless.rollout(50, break_when_any_done=False)
3431+
torch.manual_seed(0)
3432+
r1_stateful = env1_stateful.rollout(50, break_when_any_done=False)
3433+
torch.manual_seed(0)
3434+
r2_stateless = env2_stateless.rollout(50, break_when_any_done=False)
3435+
torch.manual_seed(0)
3436+
r2_stateful = env2_stateful.rollout(50, break_when_any_done=False)
3437+
torch.manual_seed(0)
3438+
r0_stateless = env0_stateless.rollout(50, break_when_any_done=False)
3439+
torch.manual_seed(0)
3440+
r0_stateful = env0_stateful.rollout(50, break_when_any_done=False)
3441+
assert (r0_stateless["action"] == r1_stateless["action"]).all()
3442+
assert (r0_stateless["action"] == r2_stateless["action"]).all()
3443+
assert (r0_stateless["action"] == r0_stateful["action"]).all()
3444+
assert (r1_stateless["action"] == r1_stateful["action"]).all()
3445+
assert (r2_stateless["action"] == r2_stateful["action"]).all()
3446+
3447+
@pytest.mark.parametrize(
3448+
"include_fen,include_pgn", [[True, False], [False, True], [True, True]]
3449+
)
3450+
@pytest.mark.parametrize("stateful", [False, True])
3451+
def test_rollout(self, stateful, include_pgn, include_fen):
3452+
torch.manual_seed(0)
3453+
env = ChessEnv(
3454+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3455+
)
3456+
r = env.rollout(500, break_when_any_done=False)
3457+
assert r.shape == (500,)
33993458

3400-
def test_reset_white_to_move(self, stateful):
3401-
env = ChessEnv(stateful=stateful)
3459+
@pytest.mark.parametrize(
3460+
"include_fen,include_pgn", [[True, False], [False, True], [True, True]]
3461+
)
3462+
@pytest.mark.parametrize("stateful", [False, True])
3463+
def test_reset_white_to_move(self, stateful, include_pgn, include_fen):
3464+
env = ChessEnv(
3465+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3466+
)
34023467
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
34033468
td = env.reset(TensorDict({"fen": fen}))
34043469
assert td["fen"] == fen
3470+
if include_fen:
3471+
assert env.board.fen() == fen
34053472
assert td["turn"] == env.lib.WHITE
34063473
assert not td["done"]
34073474

3408-
def test_reset_black_to_move(self, stateful):
3409-
env = ChessEnv(stateful=stateful)
3475+
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
3476+
@pytest.mark.parametrize("stateful", [False, True])
3477+
def test_reset_black_to_move(self, stateful, include_pgn, include_fen):
3478+
env = ChessEnv(
3479+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3480+
)
34103481
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
34113482
td = env.reset(TensorDict({"fen": fen}))
34123483
assert td["fen"] == fen
3484+
assert env.board.fen() == fen
34133485
assert td["turn"] == env.lib.BLACK
34143486
assert not td["done"]
34153487

3416-
def test_reset_done_error(self, stateful):
3417-
env = ChessEnv(stateful=stateful)
3488+
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
3489+
@pytest.mark.parametrize("stateful", [False, True])
3490+
def test_reset_done_error(self, stateful, include_pgn, include_fen):
3491+
env = ChessEnv(
3492+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3493+
)
34183494
fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
34193495
with pytest.raises(ValueError) as e_info:
34203496
env.reset(TensorDict({"fen": fen}))
@@ -3425,12 +3501,19 @@ def test_reset_done_error(self, stateful):
34253501
@pytest.mark.parametrize(
34263502
"endstate", ["white win", "black win", "stalemate", "50 move", "insufficient"]
34273503
)
3428-
def test_reward(self, stateful, reset_without_fen, endstate):
3504+
@pytest.mark.parametrize("include_pgn", [False, True])
3505+
@pytest.mark.parametrize("include_fen", [True])
3506+
@pytest.mark.parametrize("stateful", [False, True])
3507+
def test_reward(
3508+
self, stateful, reset_without_fen, endstate, include_pgn, include_fen
3509+
):
34293510
if stateful and reset_without_fen:
34303511
# reset_without_fen is only used for stateless env
34313512
return
34323513

3433-
env = ChessEnv(stateful=stateful)
3514+
env = ChessEnv(
3515+
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
3516+
)
34343517

34353518
if endstate == "white win":
34363519
fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
@@ -3443,28 +3526,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
34433526
fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
34443527
expected_turn = env.lib.BLACK
34453528
move = "Rg1#"
3446-
expected_reward = -1
3529+
expected_reward = 1
34473530
expected_done = True
34483531

34493532
elif endstate == "stalemate":
34503533
fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
34513534
expected_turn = env.lib.BLACK
34523535
move = "Rb7"
3453-
expected_reward = 0
3536+
expected_reward = 0.5
34543537
expected_done = True
34553538

34563539
elif endstate == "insufficient":
34573540
fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
34583541
expected_turn = env.lib.WHITE
34593542
move = "Kxd4"
3460-
expected_reward = 0
3543+
expected_reward = 0.5
34613544
expected_done = True
34623545

34633546
elif endstate == "50 move":
34643547
fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
34653548
expected_turn = env.lib.BLACK
34663549
move = "Kf7"
3467-
expected_reward = 0
3550+
expected_reward = 0.5
34683551
expected_done = True
34693552

34703553
elif endstate == "not_done":
@@ -3483,8 +3566,7 @@ def test_reward(self, stateful, reset_without_fen, endstate):
34833566
td = env.reset(TensorDict({"fen": fen}))
34843567
assert td["turn"] == expected_turn
34853568

3486-
moves = env.get_legal_moves(None if stateful else td)
3487-
td["action"] = moves.index(move)
3569+
td["action"] = env._san_moves.index(move)
34883570
td = env.step(td)["next"]
34893571
assert td["done"] == expected_done
34903572
assert td["reward"] == expected_reward

torchrl/envs/batched_envs.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -718,17 +718,13 @@ def _create_td(self) -> None:
718718
env_output_keys = set()
719719
env_obs_keys = set()
720720
for meta_data in self.meta_data:
721-
env_obs_keys = env_obs_keys.union(
722-
key
723-
for key in meta_data.specs["output_spec"][
724-
"full_observation_spec"
725-
].keys(True, True)
726-
)
727-
env_output_keys = env_output_keys.union(
728-
meta_data.specs["output_spec"]["full_observation_spec"].keys(
729-
True, True
730-
)
721+
keys = meta_data.specs["output_spec"]["full_observation_spec"].keys(
722+
True, True
731723
)
724+
keys = list(keys)
725+
env_obs_keys = env_obs_keys.union(keys)
726+
727+
env_output_keys = env_output_keys.union(keys)
732728
env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys)
733729
self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
734730
self._env_input_keys = sorted(env_input_keys, key=_sort_keys)
@@ -1003,7 +999,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
1003999
for i, _env in enumerate(self._envs):
10041000
if not needs_resetting[i]:
10051001
if out_tds is not None and tensordict is not None:
1006-
out_tds[i] = tensordict[i].exclude(*self._envs[i].reset_keys)
1002+
ftd = _env.observation_spec.zero()
1003+
if self.device is None:
1004+
ftd.clear_device_()
1005+
else:
1006+
ftd = ftd.to(self.device)
1007+
out_tds[i] = ftd
10071008
continue
10081009
if tensordict is not None:
10091010
tensordict_ = tensordict[i]

torchrl/envs/common.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -2505,11 +2505,26 @@ def reset(
25052505
Returns:
25062506
a tensordict (or the input tensordict, if any), modified in place with the resulting observations.
25072507
2508+
.. note:: `reset` should not be overridden by :class:`~torchrl.envs.EnvBase` subclasses. The method to
2509+
modify is :meth:`~torchrl.envs.EnvBase._reset`.
2510+
25082511
"""
25092512
if tensordict is not None:
25102513
self._assert_tensordict_shape(tensordict)
25112514

2512-
tensordict_reset = self._reset(tensordict, **kwargs)
2515+
select_reset_only = kwargs.pop("select_reset_only", False)
2516+
if select_reset_only and tensordict is not None:
2517+
# When making rollouts with step_and_maybe_reset, it can happen that a tensordict has
2518+
# keys that are used by reset to optionally set the reset state (e.g., the FEN in chess). If that's the
2519+
# case and we don't throw them away here, reset will just be a no-op (put the env in the state reached
2520+
# during the previous step).
2521+
# Therefore, maybe_reset tells reset to temporarily hide the non-reset keys.
2522+
# To make step_and_maybe_reset handle custom reset states, some version of TensorDictPrimer should be used.
2523+
tensordict_reset = self._reset(
2524+
tensordict.select(*self.reset_keys, strict=False), **kwargs
2525+
)
2526+
else:
2527+
tensordict_reset = self._reset(tensordict, **kwargs)
25132528
# We assume that this is done properly
25142529
# if reset.device != self.device:
25152530
# reset = reset.to(self.device, non_blocking=True)
@@ -3292,7 +3307,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
32923307
else:
32933308
any_done = False
32943309
if any_done:
3295-
tensordict._set_str(
3310+
tensordict = tensordict._set_str(
32963311
"_reset",
32973312
done.clone(),
32983313
validated=True,
@@ -3306,7 +3321,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
33063321
key="_reset",
33073322
)
33083323
if any_done:
3309-
tensordict = self.reset(tensordict)
3324+
return self.reset(tensordict, select_reset_only=True)
33103325
return tensordict
33113326

33123327
def empty_cache(self):

0 commit comments

Comments
 (0)