131
131
from torchrl .envs .transforms .transforms import (
132
132
AutoResetEnv ,
133
133
AutoResetTransform ,
134
+ Tokenizer ,
134
135
Transform ,
135
136
)
136
137
from torchrl .envs .utils import (
@@ -3441,35 +3442,148 @@ def test_partial_rest(self, batched):
3441
3442
3442
3443
# fen strings for board positions generated with:
3443
3444
# https://lichess.org/editor
3444
- @pytest .mark .parametrize ("stateful" , [False , True ])
3445
3445
@pytest .mark .skipif (not _has_chess , reason = "chess not found" )
3446
3446
class TestChessEnv :
3447
- def test_env (self , stateful ):
3448
- env = ChessEnv (stateful = stateful )
3449
- check_env_specs (env )
3447
+ @pytest .mark .parametrize ("include_pgn" , [False , True ])
3448
+ @pytest .mark .parametrize ("include_fen" , [False , True ])
3449
+ @pytest .mark .parametrize ("stateful" , [False , True ])
3450
+ @pytest .mark .parametrize ("include_hash" , [False , True ])
3451
+ @pytest .mark .parametrize ("include_san" , [False , True ])
3452
+ def test_env (self , stateful , include_pgn , include_fen , include_hash , include_san ):
3453
+ with pytest .raises (
3454
+ RuntimeError , match = "At least one state representation"
3455
+ ) if not stateful and not include_pgn and not include_fen else contextlib .nullcontext ():
3456
+ env = ChessEnv (
3457
+ stateful = stateful ,
3458
+ include_pgn = include_pgn ,
3459
+ include_fen = include_fen ,
3460
+ include_hash = include_hash ,
3461
+ include_san = include_san ,
3462
+ )
3463
+ # Because we always use mask_actions=True
3464
+ assert isinstance (env , TransformedEnv )
3465
+ check_env_specs (env )
3466
+ if include_hash :
3467
+ if include_fen :
3468
+ assert "fen_hash" in env .observation_spec .keys ()
3469
+ if include_pgn :
3470
+ assert "pgn_hash" in env .observation_spec .keys ()
3471
+ if include_san :
3472
+ assert "san_hash" in env .observation_spec .keys ()
3473
+
3474
+ def test_pgn_bijectivity (self ):
3475
+ np .random .seed (0 )
3476
+ pgn = ChessEnv ._PGN_RESTART
3477
+ board = ChessEnv ._pgn_to_board (pgn )
3478
+ pgn_prev = pgn
3479
+ for _ in range (10 ):
3480
+ moves = list (board .legal_moves )
3481
+ move = np .random .choice (moves )
3482
+ board .push (move )
3483
+ pgn_move = ChessEnv ._board_to_pgn (board )
3484
+ assert pgn_move != pgn_prev
3485
+ assert pgn_move == ChessEnv ._board_to_pgn (ChessEnv ._pgn_to_board (pgn_move ))
3486
+ assert pgn_move == ChessEnv ._add_move_to_pgn (pgn_prev , move )
3487
+ pgn_prev = pgn_move
3488
+
3489
+ def test_consistency (self ):
3490
+ env0_stateful = ChessEnv (stateful = True , include_pgn = True , include_fen = True )
3491
+ env1_stateful = ChessEnv (stateful = True , include_pgn = False , include_fen = True )
3492
+ env2_stateful = ChessEnv (stateful = True , include_pgn = True , include_fen = False )
3493
+ env0_stateless = ChessEnv (stateful = False , include_pgn = True , include_fen = True )
3494
+ env1_stateless = ChessEnv (stateful = False , include_pgn = False , include_fen = True )
3495
+ env2_stateless = ChessEnv (stateful = False , include_pgn = True , include_fen = False )
3496
+ torch .manual_seed (0 )
3497
+ r1_stateless = env1_stateless .rollout (50 , break_when_any_done = False )
3498
+ torch .manual_seed (0 )
3499
+ r1_stateful = env1_stateful .rollout (50 , break_when_any_done = False )
3500
+ torch .manual_seed (0 )
3501
+ r2_stateless = env2_stateless .rollout (50 , break_when_any_done = False )
3502
+ torch .manual_seed (0 )
3503
+ r2_stateful = env2_stateful .rollout (50 , break_when_any_done = False )
3504
+ torch .manual_seed (0 )
3505
+ r0_stateless = env0_stateless .rollout (50 , break_when_any_done = False )
3506
+ torch .manual_seed (0 )
3507
+ r0_stateful = env0_stateful .rollout (50 , break_when_any_done = False )
3508
+ assert (r0_stateless ["action" ] == r1_stateless ["action" ]).all ()
3509
+ assert (r0_stateless ["action" ] == r2_stateless ["action" ]).all ()
3510
+ assert (r0_stateless ["action" ] == r0_stateful ["action" ]).all ()
3511
+ assert (r1_stateless ["action" ] == r1_stateful ["action" ]).all ()
3512
+ assert (r2_stateless ["action" ] == r2_stateful ["action" ]).all ()
3513
+
3514
+ @pytest .mark .parametrize (
3515
+ "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]
3516
+ )
3517
+ @pytest .mark .parametrize ("stateful" , [False , True ])
3518
+ def test_san (self , stateful , include_fen , include_pgn ):
3519
+ torch .manual_seed (0 )
3520
+ env = ChessEnv (
3521
+ stateful = stateful ,
3522
+ include_pgn = include_pgn ,
3523
+ include_fen = include_fen ,
3524
+ include_san = True ,
3525
+ )
3526
+ r = env .rollout (100 , break_when_any_done = False )
3527
+ sans = r ["next" , "san" ]
3528
+ actions = [env .san_moves .index (san ) for san in sans ]
3529
+ i = 0
3530
+
3531
+ def policy (td ):
3532
+ nonlocal i
3533
+ td ["action" ] = actions [i ]
3534
+ i += 1
3535
+ return td
3450
3536
3451
- def test_rollout (self , stateful ):
3452
- env = ChessEnv (stateful = stateful )
3453
- env .rollout (5000 )
3537
+ r2 = env .rollout (100 , policy = policy , break_when_any_done = False )
3538
+ assert_allclose_td (r , r2 )
3454
3539
3455
- def test_reset_white_to_move (self , stateful ):
3456
- env = ChessEnv (stateful = stateful )
3540
+ @pytest .mark .parametrize (
3541
+ "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]
3542
+ )
3543
+ @pytest .mark .parametrize ("stateful" , [False , True ])
3544
+ def test_rollout (self , stateful , include_pgn , include_fen ):
3545
+ torch .manual_seed (0 )
3546
+ env = ChessEnv (
3547
+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3548
+ )
3549
+ r = env .rollout (500 , break_when_any_done = False )
3550
+ assert r .shape == (500 ,)
3551
+
3552
+ @pytest .mark .parametrize (
3553
+ "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]
3554
+ )
3555
+ @pytest .mark .parametrize ("stateful" , [False , True ])
3556
+ def test_reset_white_to_move (self , stateful , include_pgn , include_fen ):
3557
+ env = ChessEnv (
3558
+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3559
+ )
3457
3560
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
3458
3561
td = env .reset (TensorDict ({"fen" : fen }))
3459
- assert td ["fen" ] == fen
3562
+ if include_fen :
3563
+ assert td ["fen" ] == fen
3564
+ assert env .board .fen () == fen
3460
3565
assert td ["turn" ] == env .lib .WHITE
3461
3566
assert not td ["done" ]
3462
3567
3463
- def test_reset_black_to_move (self , stateful ):
3464
- env = ChessEnv (stateful = stateful )
3568
+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]])
3569
+ @pytest .mark .parametrize ("stateful" , [False , True ])
3570
+ def test_reset_black_to_move (self , stateful , include_pgn , include_fen ):
3571
+ env = ChessEnv (
3572
+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3573
+ )
3465
3574
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
3466
3575
td = env .reset (TensorDict ({"fen" : fen }))
3467
3576
assert td ["fen" ] == fen
3577
+ assert env .board .fen () == fen
3468
3578
assert td ["turn" ] == env .lib .BLACK
3469
3579
assert not td ["done" ]
3470
3580
3471
- def test_reset_done_error (self , stateful ):
3472
- env = ChessEnv (stateful = stateful )
3581
+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]])
3582
+ @pytest .mark .parametrize ("stateful" , [False , True ])
3583
+ def test_reset_done_error (self , stateful , include_pgn , include_fen ):
3584
+ env = ChessEnv (
3585
+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3586
+ )
3473
3587
fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
3474
3588
with pytest .raises (ValueError ) as e_info :
3475
3589
env .reset (TensorDict ({"fen" : fen }))
@@ -3480,12 +3594,19 @@ def test_reset_done_error(self, stateful):
3480
3594
@pytest .mark .parametrize (
3481
3595
"endstate" , ["white win" , "black win" , "stalemate" , "50 move" , "insufficient" ]
3482
3596
)
3483
- def test_reward (self , stateful , reset_without_fen , endstate ):
3597
+ @pytest .mark .parametrize ("include_pgn" , [False , True ])
3598
+ @pytest .mark .parametrize ("include_fen" , [True ])
3599
+ @pytest .mark .parametrize ("stateful" , [False , True ])
3600
+ def test_reward (
3601
+ self , stateful , reset_without_fen , endstate , include_pgn , include_fen
3602
+ ):
3484
3603
if stateful and reset_without_fen :
3485
3604
# reset_without_fen is only used for stateless env
3486
3605
return
3487
3606
3488
- env = ChessEnv (stateful = stateful )
3607
+ env = ChessEnv (
3608
+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3609
+ )
3489
3610
3490
3611
if endstate == "white win" :
3491
3612
fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
@@ -3498,28 +3619,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
3498
3619
fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
3499
3620
expected_turn = env .lib .BLACK
3500
3621
move = "Rg1#"
3501
- expected_reward = - 1
3622
+ expected_reward = 1
3502
3623
expected_done = True
3503
3624
3504
3625
elif endstate == "stalemate" :
3505
3626
fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
3506
3627
expected_turn = env .lib .BLACK
3507
3628
move = "Rb7"
3508
- expected_reward = 0
3629
+ expected_reward = 0.5
3509
3630
expected_done = True
3510
3631
3511
3632
elif endstate == "insufficient" :
3512
3633
fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
3513
3634
expected_turn = env .lib .WHITE
3514
3635
move = "Kxd4"
3515
- expected_reward = 0
3636
+ expected_reward = 0.5
3516
3637
expected_done = True
3517
3638
3518
3639
elif endstate == "50 move" :
3519
3640
fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
3520
3641
expected_turn = env .lib .BLACK
3521
3642
move = "Kf7"
3522
- expected_reward = 0
3643
+ expected_reward = 0.5
3523
3644
expected_done = True
3524
3645
3525
3646
elif endstate == "not_done" :
@@ -3538,13 +3659,33 @@ def test_reward(self, stateful, reset_without_fen, endstate):
3538
3659
td = env .reset (TensorDict ({"fen" : fen }))
3539
3660
assert td ["turn" ] == expected_turn
3540
3661
3541
- moves = env .get_legal_moves (None if stateful else td )
3542
- td ["action" ] = moves .index (move )
3662
+ td ["action" ] = env ._san_moves .index (move )
3543
3663
td = env .step (td )["next" ]
3544
3664
assert td ["done" ] == expected_done
3545
3665
assert td ["reward" ] == expected_reward
3546
3666
assert td ["turn" ] == (not expected_turn )
3547
3667
3668
+ def test_chess_tokenized (self ):
3669
+ env = ChessEnv (include_fen = True , stateful = True , include_san = True )
3670
+ assert isinstance (env .observation_spec ["fen" ], NonTensor )
3671
+ env = env .append_transform (
3672
+ Tokenizer (in_keys = ["fen" ], out_keys = ["fen_tokenized" ])
3673
+ )
3674
+ assert isinstance (env .observation_spec ["fen" ], NonTensor )
3675
+ env .transform .transform_output_spec (env .base_env .output_spec )
3676
+ env .transform .transform_input_spec (env .base_env .input_spec )
3677
+ r = env .rollout (10 , return_contiguous = False )
3678
+ assert "fen_tokenized" in r
3679
+ assert "fen" in r
3680
+ assert "fen_tokenized" in r ["next" ]
3681
+ assert "fen" in r ["next" ]
3682
+ ftd = env .fake_tensordict ()
3683
+ assert "fen_tokenized" in ftd
3684
+ assert "fen" in ftd
3685
+ assert "fen_tokenized" in ftd ["next" ]
3686
+ assert "fen" in ftd ["next" ]
3687
+ env .check_env_specs ()
3688
+
3548
3689
3549
3690
class TestCustomEnvs :
3550
3691
def test_tictactoe_env (self ):
0 commit comments