Skip to content

Commit 812e8ff

Browse files
committed
[Test] Improve coverage of ChessEnv.all_actions
ghstack-source-id: b7623ef Pull Request resolved: #2849
1 parent 27d3680 commit 812e8ff

File tree

1 file changed

+41
-20
lines changed

1 file changed

+41
-20
lines changed

test/test_env.py

+41-20
Original file line numberDiff line numberDiff line change
@@ -4225,43 +4225,62 @@ def test_env_reset_with_hash(self, stateful, include_san):
42254225
td_check = env.reset(td.select("fen_hash"))
42264226
assert (td_check == td).all()
42274227

4228-
@pytest.mark.parametrize("include_fen", [False, True])
4229-
@pytest.mark.parametrize("include_pgn", [False, True])
4228+
@pytest.mark.parametrize("include_fen,include_pgn", [[False, True], [True, False]])
42304229
@pytest.mark.parametrize("stateful", [False, True])
4231-
@pytest.mark.parametrize("mask_actions", [False, True])
4232-
def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
4233-
if not stateful and not include_fen and not include_pgn:
4234-
# pytest.skip("fen or pgn must be included if not stateful")
4235-
return
4236-
4230+
@pytest.mark.parametrize("include_hash", [False, True])
4231+
@pytest.mark.parametrize("include_san", [False, True])
4232+
@pytest.mark.parametrize("append_transform", [False, True])
4233+
@pytest.mark.parametrize("mask_actions", [True])
4234+
def test_all_actions(
4235+
self,
4236+
include_fen,
4237+
include_pgn,
4238+
stateful,
4239+
include_hash,
4240+
include_san,
4241+
append_transform,
4242+
mask_actions,
4243+
):
42374244
env = ChessEnv(
42384245
include_fen=include_fen,
42394246
include_pgn=include_pgn,
4247+
include_san=include_san,
4248+
include_hash=include_hash,
4249+
include_hash_inv=include_hash,
42404250
stateful=stateful,
42414251
mask_actions=mask_actions,
42424252
)
4243-
td = env.reset()
42444253

4245-
if not mask_actions:
4246-
with pytest.raises(RuntimeError, match="Cannot generate legal actions"):
4247-
env.all_actions()
4248-
return
4254+
def transform_reward(td):
4255+
if "reward" not in td:
4256+
return td
4257+
reward = td["reward"]
4258+
if reward == 0.5:
4259+
td["reward"] = 0
4260+
elif reward == 1 and td["turn"]:
4261+
td["reward"] = -td["reward"]
4262+
return td
4263+
4264+
if append_transform:
4265+
env = env.append_transform(transform_reward)
4266+
4267+
check_env_specs(env)
4268+
4269+
td = env.reset()
42494270

42504271
# Choose random actions from the output of `all_actions`
4251-
for _ in range(100):
4252-
if stateful:
4253-
all_actions = env.all_actions()
4254-
else:
4272+
for step_idx in range(100):
4273+
if step_idx % 5 == 0:
42554274
# Reset to the initial state first, just to make sure
42564275
# `all_actions` knows how to get the board state from the input.
42574276
env.reset()
4258-
all_actions = env.all_actions(td.clone())
4277+
all_actions = env.all_actions(td.clone())
42594278

42604279
# Choose some random actions and make sure they match exactly one of
42614280
# the actions from `all_actions`. This part is not tested when
42624281
# `mask_actions == False`, because `rand_action` can pick illegal
42634282
# actions in that case.
4264-
if mask_actions:
4283+
if mask_actions and step_idx % 4 == 0:
42654284
# TODO: Something is wrong in `ChessEnv.rand_action` which makes
42664285
# it fail to work properly for stateless mode. It doesn't know
42674286
# how to correctly reset the board state to what is given in the
@@ -4278,7 +4297,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
42784297

42794298
action_idx = torch.randint(0, all_actions.shape[0], ()).item()
42804299
chosen_action = all_actions[action_idx]
4281-
td = env.step(td.update(chosen_action))["next"]
4300+
td_new = env.step(td.update(chosen_action).clone())
4301+
assert (td == td_new.exclude("next")).all()
4302+
td = td_new["next"]
42824303

42834304
if td["done"]:
42844305
td = env.reset()

0 commit comments

Comments
 (0)