@@ -3291,6 +3291,10 @@ def test_batched_dynamic(self, break_when_any_done):
32913291        )
32923292        del  env_no_buffers 
32933293        gc .collect ()
3294+         # print(dummy_rollouts) 
3295+         # print(rollout_no_buffers_serial) 
3296+         # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)): 
3297+         #     assert_allclose_td(a, b) 
32943298        assert_allclose_td (
32953299            dummy_rollouts .exclude ("action" ),
32963300            rollout_no_buffers_serial .exclude ("action" ),
@@ -3386,35 +3390,107 @@ def test_partial_rest(self, batched):
33863390
33873391# fen strings for board positions generated with: 
33883392# https://lichess.org/editor 
3389- @pytest .mark .parametrize ("stateful" , [False , True ]) 
33903393@pytest .mark .skipif (not  _has_chess , reason = "chess not found" ) 
33913394class  TestChessEnv :
3392-     def  test_env (self , stateful ):
3393-         env  =  ChessEnv (stateful = stateful )
3394-         check_env_specs (env )
3395+     @pytest .mark .parametrize ("include_pgn" , [False , True ]) 
3396+     @pytest .mark .parametrize ("include_fen" , [False , True ]) 
3397+     @pytest .mark .parametrize ("stateful" , [False , True ]) 
3398+     def  test_env (self , stateful , include_pgn , include_fen ):  # builds a ChessEnv for each flag combination and runs check_env_specs on it
3399+         with  pytest .raises (  # the context manager is chosen by the conditional expression on the closing line
3400+             RuntimeError , match = "At least one state representation" 
3401+         ) if  not  stateful  and  not  include_pgn  and  not  include_fen  else  contextlib .nullcontext ():  # only stateless with neither PGN nor FEN is expected to raise; every other combination runs under a no-op context
3402+             env  =  ChessEnv (
3403+                 stateful = stateful , include_pgn = include_pgn , include_fen = include_fen 
3404+             )
3405+             check_env_specs (env )  # not reached in the raising case, since the constructor call above is inside the same with-block
33953406
3396-     def  test_rollout (self , stateful ):
3397-         env  =  ChessEnv (stateful = stateful )
3398-         env .rollout (5000 )
3407+     def  test_pgn_bijectivity (self ):  # plays 10 random legal moves and checks the ChessEnv PGN helpers agree with each other after every move
3408+         np .random .seed (0 )  # fixed seed so the sampled moves are the same on every run
3409+         pgn  =  ChessEnv ._PGN_RESTART 
3410+         board  =  ChessEnv ._pgn_to_board (pgn )
3411+         pgn_prev  =  pgn 
3412+         for  _  in  range (10 ):
3413+             moves  =  list (board .legal_moves )
3414+             move  =  np .random .choice (moves )
3415+             board .push (move )
3416+             pgn_move  =  ChessEnv ._board_to_pgn (board )
3417+             assert  pgn_move  !=  pgn_prev   # pushing a move must change the PGN string
3418+             assert  pgn_move  ==  ChessEnv ._board_to_pgn (ChessEnv ._pgn_to_board (pgn_move ))  # PGN -> board -> PGN must reproduce the same string
3419+             assert  pgn_move  ==  ChessEnv ._add_move_to_pgn (pgn_prev , move )  # appending the move to the previous PGN must match the PGN rebuilt from the board
3420+             pgn_prev  =  pgn_move 
3421+ 
3422+     def  test_consistency (self ):  # rollouts started from the same torch seed must contain identical actions across the compared configurations
3423+         env0_stateful  =  ChessEnv (stateful = True , include_pgn = True , include_fen = True )  # index 0: PGN and FEN both enabled
3424+         env1_stateful  =  ChessEnv (stateful = True , include_pgn = False , include_fen = True )  # index 1: FEN only
3425+         env2_stateful  =  ChessEnv (stateful = True , include_pgn = True , include_fen = False )  # index 2: PGN only
3426+         env0_stateless  =  ChessEnv (stateful = False , include_pgn = True , include_fen = True )
3427+         env1_stateless  =  ChessEnv (stateful = False , include_pgn = False , include_fen = True )
3428+         env2_stateless  =  ChessEnv (stateful = False , include_pgn = True , include_fen = False )
3429+         torch .manual_seed (0 )  # the seed is reset before every rollout below so each one starts from the same torch RNG state
3430+         r1_stateless  =  env1_stateless .rollout (50 , break_when_any_done = False )
3431+         torch .manual_seed (0 )
3432+         r1_stateful  =  env1_stateful .rollout (50 , break_when_any_done = False )
3433+         torch .manual_seed (0 )
3434+         r2_stateless  =  env2_stateless .rollout (50 , break_when_any_done = False )
3435+         torch .manual_seed (0 )
3436+         r2_stateful  =  env2_stateful .rollout (50 , break_when_any_done = False )
3437+         torch .manual_seed (0 )
3438+         r0_stateless  =  env0_stateless .rollout (50 , break_when_any_done = False )
3439+         torch .manual_seed (0 )
3440+         r0_stateful  =  env0_stateful .rollout (50 , break_when_any_done = False )
3441+         assert  (r0_stateless ["action" ] ==  r1_stateless ["action" ]).all ()  # stateless: PGN+FEN vs FEN only
3442+         assert  (r0_stateless ["action" ] ==  r2_stateless ["action" ]).all ()  # stateless: PGN+FEN vs PGN only
3443+         assert  (r0_stateless ["action" ] ==  r0_stateful ["action" ]).all ()  # PGN+FEN: stateless vs stateful
3444+         assert  (r1_stateless ["action" ] ==  r1_stateful ["action" ]).all ()  # FEN only: stateless vs stateful
3445+         assert  (r2_stateless ["action" ] ==  r2_stateful ["action" ]).all ()  # PGN only: stateless vs stateful
3446+ 
3447+     @pytest .mark .parametrize ( 
3448+         "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]  # every pair with at least one of FEN / PGN enabled
3449+     ) 
3450+     @pytest .mark .parametrize ("stateful" , [False , True ]) 
3451+     def  test_rollout (self , stateful , include_pgn , include_fen ):  # a 500-step rollout with break_when_any_done=False must come back with shape (500,)
3452+         torch .manual_seed (0 )  # fixed torch seed before building the env and rolling out
3453+         env  =  ChessEnv (
3454+             stateful = stateful , include_pgn = include_pgn , include_fen = include_fen 
3455+         )
3456+         r  =  env .rollout (500 , break_when_any_done = False )
3457+         assert  r .shape  ==  (500 ,)
33993458
3400-     def  test_reset_white_to_move (self , stateful ):
3401-         env  =  ChessEnv (stateful = stateful )
3459+     @pytest .mark .parametrize ( 
3460+         "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]  # every pair with at least one of FEN / PGN enabled
3461+     ) 
3462+     @pytest .mark .parametrize ("stateful" , [False , True ]) 
3463+     def  test_reset_white_to_move (self , stateful , include_pgn , include_fen ):  # resets from a given FEN and checks the returned fen, turn and done entries
3464+         env  =  ChessEnv (
3465+             stateful = stateful , include_pgn = include_pgn , include_fen = include_fen 
3466+         )
34023467        fen  =  "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"   # position with white to move (turn is asserted to be WHITE below)
34033468        td  =  env .reset (TensorDict ({"fen" : fen }))
34043469        assert  td ["fen" ] ==  fen   # NOTE(review): asserted for every parametrization, including include_fen=False - confirm "fen" is always present in the reset output
3470+         if  include_fen :  # env.board is only compared against the FEN when include_fen is set
3471+             assert  env .board .fen () ==  fen 
34053472        assert  td ["turn" ] ==  env .lib .WHITE 
34063473        assert  not  td ["done" ]
34073474
3408-     def  test_reset_black_to_move (self , stateful ):
3409-         env  =  ChessEnv (stateful = stateful )
3475+     @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]])  # only pairs with include_fen=True are exercised here
3476+     @pytest .mark .parametrize ("stateful" , [False , True ]) 
3477+     def  test_reset_black_to_move (self , stateful , include_pgn , include_fen ):  # resets from a given FEN and checks the returned fen, the env board, turn and done
3478+         env  =  ChessEnv (
3479+             stateful = stateful , include_pgn = include_pgn , include_fen = include_fen 
3480+         )
34103481        fen  =  "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"   # position with black to move (turn is asserted to be BLACK below)
34113482        td  =  env .reset (TensorDict ({"fen" : fen }))
34123483        assert  td ["fen" ] ==  fen 
3484+         assert  env .board .fen () ==  fen   # unconditional here because include_fen is always True in this parametrization
34133485        assert  td ["turn" ] ==  env .lib .BLACK 
34143486        assert  not  td ["done" ]
34153487
3416-     def  test_reset_done_error (self , stateful ):
3417-         env  =  ChessEnv (stateful = stateful )
3488+     @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]]) 
3489+     @pytest .mark .parametrize ("stateful" , [False , True ]) 
3490+     def  test_reset_done_error (self , stateful , include_pgn , include_fen ):
3491+         env  =  ChessEnv (
3492+             stateful = stateful , include_pgn = include_pgn , include_fen = include_fen 
3493+         )
34183494        fen  =  "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1" 
34193495        with  pytest .raises (ValueError ) as  e_info :
34203496            env .reset (TensorDict ({"fen" : fen }))
@@ -3425,12 +3501,19 @@ def test_reset_done_error(self, stateful):
34253501    @pytest .mark .parametrize ( 
34263502        "endstate" , ["white win" , "black win" , "stalemate" , "50 move" , "insufficient" ] 
34273503    ) 
3428-     def  test_reward (self , stateful , reset_without_fen , endstate ):
3504+     @pytest .mark .parametrize ("include_pgn" , [False , True ]) 
3505+     @pytest .mark .parametrize ("include_fen" , [True ]) 
3506+     @pytest .mark .parametrize ("stateful" , [False , True ]) 
3507+     def  test_reward (
3508+         self , stateful , reset_without_fen , endstate , include_pgn , include_fen 
3509+     ):
34293510        if  stateful  and  reset_without_fen :
34303511            # reset_without_fen is only used for stateless env 
34313512            return 
34323513
3433-         env  =  ChessEnv (stateful = stateful )
3514+         env  =  ChessEnv (
3515+             stateful = stateful , include_pgn = include_pgn , include_fen = include_fen 
3516+         )
34343517
34353518        if  endstate  ==  "white win" :
34363519            fen  =  "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1" 
@@ -3443,28 +3526,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
34433526            fen  =  "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1" 
34443527            expected_turn  =  env .lib .BLACK 
34453528            move  =  "Rg1#" 
3446-             expected_reward  =  - 1 
3529+             expected_reward  =  1 
34473530            expected_done  =  True 
34483531
34493532        elif  endstate  ==  "stalemate" :
34503533            fen  =  "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1" 
34513534            expected_turn  =  env .lib .BLACK 
34523535            move  =  "Rb7" 
3453-             expected_reward  =  0 
3536+             expected_reward  =  0.5  
34543537            expected_done  =  True 
34553538
34563539        elif  endstate  ==  "insufficient" :
34573540            fen  =  "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1" 
34583541            expected_turn  =  env .lib .WHITE 
34593542            move  =  "Kxd4" 
3460-             expected_reward  =  0 
3543+             expected_reward  =  0.5  
34613544            expected_done  =  True 
34623545
34633546        elif  endstate  ==  "50 move" :
34643547            fen  =  "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123" 
34653548            expected_turn  =  env .lib .BLACK 
34663549            move  =  "Kf7" 
3467-             expected_reward  =  0 
3550+             expected_reward  =  0.5  
34683551            expected_done  =  True 
34693552
34703553        elif  endstate  ==  "not_done" :
@@ -3483,8 +3566,7 @@ def test_reward(self, stateful, reset_without_fen, endstate):
34833566            td  =  env .reset (TensorDict ({"fen" : fen }))
34843567            assert  td ["turn" ] ==  expected_turn 
34853568
3486-         moves  =  env .get_legal_moves (None  if  stateful  else  td )
3487-         td ["action" ] =  moves .index (move )
3569+         td ["action" ] =  env ._san_moves .index (move )
34883570        td  =  env .step (td )["next" ]
34893571        assert  td ["done" ] ==  expected_done 
34903572        assert  td ["reward" ] ==  expected_reward 
0 commit comments