@@ -3291,6 +3291,10 @@ def test_batched_dynamic(self, break_when_any_done):
3291
3291
)
3292
3292
del env_no_buffers
3293
3293
gc .collect ()
3294
+ # print(dummy_rollouts)
3295
+ # print(rollout_no_buffers_serial)
3296
+ # # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)):
3297
+ # assert_allclose_td(a, b)
3294
3298
assert_allclose_td (
3295
3299
dummy_rollouts .exclude ("action" ),
3296
3300
rollout_no_buffers_serial .exclude ("action" ),
@@ -3386,35 +3390,107 @@ def test_partial_rest(self, batched):
3386
3390
3387
3391
# fen strings for board positions generated with:
3388
3392
# https://lichess.org/editor
3389
- @pytest .mark .parametrize ("stateful" , [False , True ])
3390
3393
@pytest .mark .skipif (not _has_chess , reason = "chess not found" )
3391
3394
class TestChessEnv :
3392
- def test_env (self , stateful ):
3393
- env = ChessEnv (stateful = stateful )
3394
- check_env_specs (env )
3395
@pytest.mark.parametrize("include_pgn", [False, True])
@pytest.mark.parametrize("include_fen", [False, True])
@pytest.mark.parametrize("stateful", [False, True])
def test_env(self, stateful, include_pgn, include_fen):
    """Build a ChessEnv for each flag combination and run the spec check.

    When ``stateful``, ``include_pgn`` and ``include_fen`` are all False,
    the body is expected to raise a ``RuntimeError`` whose message matches
    "At least one state representation"; otherwise it must run cleanly.
    """
    # All three flags off is the only combination expected to fail.
    no_state_representation = not (stateful or include_pgn or include_fen)
    if no_state_representation:
        guard = pytest.raises(
            RuntimeError, match="At least one state representation"
        )
    else:
        guard = contextlib.nullcontext()

    with guard:
        env = ChessEnv(
            stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
        )
        check_env_specs(env)
3395
3406
3396
- def test_rollout (self , stateful ):
3397
- env = ChessEnv (stateful = stateful )
3398
- env .rollout (5000 )
3407
def test_pgn_bijectivity(self):
    """Check that board <-> PGN conversion round-trips over a short random game.

    Starting from ``ChessEnv._PGN_RESTART``, ten randomly chosen legal moves
    are pushed onto the board. After each move the test asserts that:

    * the PGN produced from the board differs from the previous PGN,
    * converting that PGN to a board and back yields the same PGN,
    * appending the move to the previous PGN via ``_add_move_to_pgn``
      yields the same PGN as serialising the board.
    """
    # A local legacy RandomState seeded with 0 draws the same sequence as
    # `np.random.seed(0)` + `np.random.choice`, but leaves NumPy's global
    # RNG state untouched so other tests are not affected by this one.
    rng = np.random.RandomState(0)
    pgn_prev = ChessEnv._PGN_RESTART
    board = ChessEnv._pgn_to_board(pgn_prev)
    for _ in range(10):
        move = rng.choice(list(board.legal_moves))
        board.push(move)
        pgn_move = ChessEnv._board_to_pgn(board)
        assert pgn_move != pgn_prev
        assert pgn_move == ChessEnv._board_to_pgn(ChessEnv._pgn_to_board(pgn_move))
        assert pgn_move == ChessEnv._add_move_to_pgn(pgn_prev, move)
        pgn_prev = pgn_move
3421
+
3422
def test_consistency(self):
    """Check that seeded rollouts pick the same actions across env variants.

    Three state-representation configurations (index 0: PGN and FEN,
    1: FEN only, 2: PGN only) are each built in a stateful and a stateless
    flavour. Every rollout is preceded by ``torch.manual_seed(0)``, and the
    resulting ``"action"`` entries are compared across variants.
    """
    # index -> (include_pgn, include_fen)
    flag_table = {0: (True, True), 1: (False, True), 2: (True, False)}

    # Stateful envs are constructed first, then stateless ones.
    stateful_envs = {
        idx: ChessEnv(stateful=True, include_pgn=with_pgn, include_fen=with_fen)
        for idx, (with_pgn, with_fen) in flag_table.items()
    }
    stateless_envs = {
        idx: ChessEnv(stateful=False, include_pgn=with_pgn, include_fen=with_fen)
        for idx, (with_pgn, with_fen) in flag_table.items()
    }

    stateless_rollouts = {}
    stateful_rollouts = {}
    # Rollouts are collected for configurations 1, 2, then 0; each one starts
    # from a freshly re-seeded torch RNG.
    for idx in (1, 2, 0):
        torch.manual_seed(0)
        stateless_rollouts[idx] = stateless_envs[idx].rollout(
            50, break_when_any_done=False
        )
        torch.manual_seed(0)
        stateful_rollouts[idx] = stateful_envs[idx].rollout(
            50, break_when_any_done=False
        )

    reference_actions = stateless_rollouts[0]["action"]
    # Stateless envs agree regardless of which representations are included.
    for idx in (1, 2):
        assert (reference_actions == stateless_rollouts[idx]["action"]).all()
    # Each stateless env agrees with its stateful counterpart.
    for idx in (0, 1, 2):
        assert (
            stateless_rollouts[idx]["action"] == stateful_rollouts[idx]["action"]
        ).all()
3446
+
3447
@pytest.mark.parametrize(
    "include_fen,include_pgn", [[True, False], [False, True], [True, True]]
)
@pytest.mark.parametrize("stateful", [False, True])
def test_rollout(self, stateful, include_pgn, include_fen):
    """Run a seeded 500-step rollout and check the resulting batch shape.

    ``break_when_any_done=False`` is passed so the rollout is not cut short
    when a done flag is reached; the result must have shape ``(500,)``.
    """
    torch.manual_seed(0)
    chess_env = ChessEnv(
        stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
    )
    trajectory = chess_env.rollout(500, break_when_any_done=False)
    assert trajectory.shape == (500,)
3399
3458
3400
- def test_reset_white_to_move (self , stateful ):
3401
- env = ChessEnv (stateful = stateful )
3459
@pytest.mark.parametrize(
    "include_fen,include_pgn", [[True, False], [False, True], [True, True]]
)
@pytest.mark.parametrize("stateful", [False, True])
def test_reset_white_to_move(self, stateful, include_pgn, include_fen):
    """Reset the env from a FEN with white to move and check the output.

    The reset result must carry the same ``"fen"`` entry, report white's
    turn and not be done. The env's own board is only compared against the
    FEN when ``include_fen`` is enabled.
    """
    chess_env = ChessEnv(
        stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
    )
    fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
    reset_input = TensorDict({"fen": fen})
    td = chess_env.reset(reset_input)

    assert td["fen"] == fen
    # NOTE(review): the board check is gated on include_fen — presumably the
    # board is not loaded from "fen" otherwise; confirm against ChessEnv.
    if include_fen:
        assert chess_env.board.fen() == fen
    assert td["turn"] == chess_env.lib.WHITE
    assert not td["done"]
3407
3474
3408
- def test_reset_black_to_move (self , stateful ):
3409
- env = ChessEnv (stateful = stateful )
3475
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
@pytest.mark.parametrize("stateful", [False, True])
def test_reset_black_to_move(self, stateful, include_pgn, include_fen):
    """Reset the env from a FEN with black to move and check the output.

    Only configurations with ``include_fen=True`` are parametrized. The
    reset result and the env's board must both match the FEN, the turn must
    be black's, and the position must not be done.
    """
    chess_env = ChessEnv(
        stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
    )
    fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
    reset_input = TensorDict({"fen": fen})
    td = chess_env.reset(reset_input)

    assert td["fen"] == fen
    assert chess_env.board.fen() == fen
    assert td["turn"] == chess_env.lib.BLACK
    assert not td["done"]
3415
3487
3416
- def test_reset_done_error (self , stateful ):
3417
- env = ChessEnv (stateful = stateful )
3488
+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]])
3489
+ @pytest .mark .parametrize ("stateful" , [False , True ])
3490
+ def test_reset_done_error (self , stateful , include_pgn , include_fen ):
3491
+ env = ChessEnv (
3492
+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3493
+ )
3418
3494
fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
3419
3495
with pytest .raises (ValueError ) as e_info :
3420
3496
env .reset (TensorDict ({"fen" : fen }))
@@ -3425,12 +3501,19 @@ def test_reset_done_error(self, stateful):
3425
3501
@pytest .mark .parametrize (
3426
3502
"endstate" , ["white win" , "black win" , "stalemate" , "50 move" , "insufficient" ]
3427
3503
)
3428
- def test_reward (self , stateful , reset_without_fen , endstate ):
3504
+ @pytest .mark .parametrize ("include_pgn" , [False , True ])
3505
+ @pytest .mark .parametrize ("include_fen" , [True ])
3506
+ @pytest .mark .parametrize ("stateful" , [False , True ])
3507
+ def test_reward (
3508
+ self , stateful , reset_without_fen , endstate , include_pgn , include_fen
3509
+ ):
3429
3510
if stateful and reset_without_fen :
3430
3511
# reset_without_fen is only used for stateless env
3431
3512
return
3432
3513
3433
- env = ChessEnv (stateful = stateful )
3514
+ env = ChessEnv (
3515
+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3516
+ )
3434
3517
3435
3518
if endstate == "white win" :
3436
3519
fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
@@ -3443,28 +3526,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
3443
3526
fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
3444
3527
expected_turn = env .lib .BLACK
3445
3528
move = "Rg1#"
3446
- expected_reward = - 1
3529
+ expected_reward = 1
3447
3530
expected_done = True
3448
3531
3449
3532
elif endstate == "stalemate" :
3450
3533
fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
3451
3534
expected_turn = env .lib .BLACK
3452
3535
move = "Rb7"
3453
- expected_reward = 0
3536
+ expected_reward = 0.5
3454
3537
expected_done = True
3455
3538
3456
3539
elif endstate == "insufficient" :
3457
3540
fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
3458
3541
expected_turn = env .lib .WHITE
3459
3542
move = "Kxd4"
3460
- expected_reward = 0
3543
+ expected_reward = 0.5
3461
3544
expected_done = True
3462
3545
3463
3546
elif endstate == "50 move" :
3464
3547
fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
3465
3548
expected_turn = env .lib .BLACK
3466
3549
move = "Kf7"
3467
- expected_reward = 0
3550
+ expected_reward = 0.5
3468
3551
expected_done = True
3469
3552
3470
3553
elif endstate == "not_done" :
@@ -3483,8 +3566,7 @@ def test_reward(self, stateful, reset_without_fen, endstate):
3483
3566
td = env .reset (TensorDict ({"fen" : fen }))
3484
3567
assert td ["turn" ] == expected_turn
3485
3568
3486
- moves = env .get_legal_moves (None if stateful else td )
3487
- td ["action" ] = moves .index (move )
3569
+ td ["action" ] = env ._san_moves .index (move )
3488
3570
td = env .step (td )["next" ]
3489
3571
assert td ["done" ] == expected_done
3490
3572
assert td ["reward" ] == expected_reward
0 commit comments