@@ -4225,43 +4225,62 @@ def test_env_reset_with_hash(self, stateful, include_san):
4225
4225
td_check = env .reset (td .select ("fen_hash" ))
4226
4226
assert (td_check == td ).all ()
4227
4227
4228
- @pytest .mark .parametrize ("include_fen" , [False , True ])
4229
- @pytest .mark .parametrize ("include_pgn" , [False , True ])
4228
+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[False , True ], [True , False ]])
4230
4229
@pytest .mark .parametrize ("stateful" , [False , True ])
4231
- @pytest .mark .parametrize ("mask_actions" , [False , True ])
4232
- def test_all_actions (self , include_fen , include_pgn , stateful , mask_actions ):
4233
- if not stateful and not include_fen and not include_pgn :
4234
- # pytest.skip("fen or pgn must be included if not stateful")
4235
- return
4236
-
4230
+ @pytest .mark .parametrize ("include_hash" , [False , True ])
4231
+ @pytest .mark .parametrize ("include_san" , [False , True ])
4232
+ @pytest .mark .parametrize ("append_transform" , [False , True ])
4233
+ @pytest .mark .parametrize ("mask_actions" , [True ])
4234
+ def test_all_actions (
4235
+ self ,
4236
+ include_fen ,
4237
+ include_pgn ,
4238
+ stateful ,
4239
+ include_hash ,
4240
+ include_san ,
4241
+ append_transform ,
4242
+ mask_actions ,
4243
+ ):
4237
4244
env = ChessEnv (
4238
4245
include_fen = include_fen ,
4239
4246
include_pgn = include_pgn ,
4247
+ include_san = include_san ,
4248
+ include_hash = include_hash ,
4249
+ include_hash_inv = include_hash ,
4240
4250
stateful = stateful ,
4241
4251
mask_actions = mask_actions ,
4242
4252
)
4243
- td = env .reset ()
4244
4253
4245
- if not mask_actions :
4246
- with pytest .raises (RuntimeError , match = "Cannot generate legal actions" ):
4247
- env .all_actions ()
4248
- return
4254
+ def transform_reward (td ):
4255
+ if "reward" not in td :
4256
+ return td
4257
+ reward = td ["reward" ]
4258
+ if reward == 0.5 :
4259
+ td ["reward" ] = 0
4260
+ elif reward == 1 and td ["turn" ]:
4261
+ td ["reward" ] = - td ["reward" ]
4262
+ return td
4263
+
4264
+ if append_transform :
4265
+ env = env .append_transform (transform_reward )
4266
+
4267
+ check_env_specs (env )
4268
+
4269
+ td = env .reset ()
4249
4270
4250
4271
# Choose random actions from the output of `all_actions`
4251
- for _ in range (100 ):
4252
- if stateful :
4253
- all_actions = env .all_actions ()
4254
- else :
4272
+ for step_idx in range (100 ):
4273
+ if step_idx % 5 == 0 :
4255
4274
# Reset to the initial state first, just to make sure
4256
4275
# `all_actions` knows how to get the board state from the input.
4257
4276
env .reset ()
4258
- all_actions = env .all_actions (td .clone ())
4277
+ all_actions = env .all_actions (td .clone ())
4259
4278
4260
4279
# Choose some random actions and make sure they match exactly one of
4261
4280
# the actions from `all_actions`. This part is not tested when
4262
4281
# `mask_actions == False`, because `rand_action` can pick illegal
4263
4282
# actions in that case.
4264
- if mask_actions :
4283
+ if mask_actions and step_idx % 4 == 0 :
4265
4284
# TODO: Something is wrong in `ChessEnv.rand_action` which makes
4266
4285
# it fail to work properly for stateless mode. It doesn't know
4267
4286
# how to correctly reset the board state to what is given in the
@@ -4278,7 +4297,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
4278
4297
4279
4298
action_idx = torch .randint (0 , all_actions .shape [0 ], ()).item ()
4280
4299
chosen_action = all_actions [action_idx ]
4281
- td = env .step (td .update (chosen_action ))["next" ]
4300
+ td_new = env .step (td .update (chosen_action ).clone ())
4301
+ assert (td == td_new .exclude ("next" )).all ()
4302
+ td = td_new ["next" ]
4282
4303
4283
4304
if td ["done" ]:
4284
4305
td = env .reset ()
0 commit comments