235235class  TestEnvBase :
236236    def  test_run_type_checks (self ):
237237        env  =  ContinuousActionVecMockEnv ()
238+         env .adapt_dtype  =  False 
238239        env ._run_type_checks  =  False 
239240        check_env_specs (env )
240241        env ._run_type_checks  =  True 
@@ -4112,17 +4113,21 @@ def test_parallel_partial_steps(
41124113                use_buffers = use_buffers ,
41134114                device = device ,
41144115            )
4115-             td  =  penv .reset ()
4116-             psteps  =  torch .zeros (4 , dtype = torch .bool )
4117-             psteps [[1 , 3 ]] =  True 
4118-             td .set ("_step" , psteps )
4119- 
4120-             td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4121-             td  =  penv .step (td )
4122-             assert  (td [0 ].get ("next" ) ==  0 ).all ()
4123-             assert  (td [1 ].get ("next" ) !=  0 ).any ()
4124-             assert  (td [2 ].get ("next" ) ==  0 ).all ()
4125-             assert  (td [3 ].get ("next" ) !=  0 ).any ()
4116+             try :
4117+                 td  =  penv .reset ()
4118+                 psteps  =  torch .zeros (4 , dtype = torch .bool )
4119+                 psteps [[1 , 3 ]] =  True 
4120+                 td .set ("_step" , psteps )
4121+ 
4122+                 td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4123+                 td  =  penv .step (td )
4124+                 assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4125+                 assert  (td [1 ].get ("next" ) !=  0 ).any ()
4126+                 assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4127+                 assert  (td [3 ].get ("next" ) !=  0 ).any ()
4128+             finally :
4129+                 penv .close ()
4130+                 del  penv 
41264131
41274132    @pytest .mark .parametrize ("use_buffers" , [False , True ]) 
41284133    def  test_parallel_partial_step_and_maybe_reset (
@@ -4135,17 +4140,21 @@ def test_parallel_partial_step_and_maybe_reset(
41354140                use_buffers = use_buffers ,
41364141                device = device ,
41374142            )
4138-             td  =  penv .reset ()
4139-             psteps  =  torch .zeros (4 , dtype = torch .bool )
4140-             psteps [[1 , 3 ]] =  True 
4141-             td .set ("_step" , psteps )
4142- 
4143-             td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4144-             td , tdreset  =  penv .step_and_maybe_reset (td )
4145-             assert  (td [0 ].get ("next" ) ==  0 ).all ()
4146-             assert  (td [1 ].get ("next" ) !=  0 ).any ()
4147-             assert  (td [2 ].get ("next" ) ==  0 ).all ()
4148-             assert  (td [3 ].get ("next" ) !=  0 ).any ()
4143+             try :
4144+                 td  =  penv .reset ()
4145+                 psteps  =  torch .zeros (4 , dtype = torch .bool , device = td .get ("done" ).device )
4146+                 psteps [[1 , 3 ]] =  True 
4147+                 td .set ("_step" , psteps )
4148+ 
4149+                 td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4150+                 td , tdreset  =  penv .step_and_maybe_reset (td )
4151+                 assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4152+                 assert  (td [1 ].get ("next" ) !=  0 ).any ()
4153+                 assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4154+                 assert  (td [3 ].get ("next" ) !=  0 ).any ()
4155+             finally :
4156+                 penv .close ()
4157+                 del  penv 
41494158
41504159    @pytest .mark .parametrize ("use_buffers" , [False , True ]) 
41514160    def  test_serial_partial_steps (self , use_buffers , device , env_device ):
@@ -4156,17 +4165,21 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
41564165                use_buffers = use_buffers ,
41574166                device = device ,
41584167            )
4159-             td  =  penv .reset ()
4160-             psteps  =  torch .zeros (4 , dtype = torch .bool )
4161-             psteps [[1 , 3 ]] =  True 
4162-             td .set ("_step" , psteps )
4163- 
4164-             td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4165-             td  =  penv .step (td )
4166-             assert  (td [0 ].get ("next" ) ==  0 ).all ()
4167-             assert  (td [1 ].get ("next" ) !=  0 ).any ()
4168-             assert  (td [2 ].get ("next" ) ==  0 ).all ()
4169-             assert  (td [3 ].get ("next" ) !=  0 ).any ()
4168+             try :
4169+                 td  =  penv .reset ()
4170+                 psteps  =  torch .zeros (4 , dtype = torch .bool )
4171+                 psteps [[1 , 3 ]] =  True 
4172+                 td .set ("_step" , psteps )
4173+ 
4174+                 td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4175+                 td  =  penv .step (td )
4176+                 assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4177+                 assert  (td [1 ].get ("next" ) !=  0 ).any ()
4178+                 assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4179+                 assert  (td [3 ].get ("next" ) !=  0 ).any ()
4180+             finally :
4181+                 penv .close ()
4182+                 del  penv 
41704183
41714184    @pytest .mark .parametrize ("use_buffers" , [False , True ]) 
41724185    def  test_serial_partial_step_and_maybe_reset (self , use_buffers , device , env_device ):
@@ -4184,9 +4197,9 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
41844197
41854198            td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
41864199            td  =  penv .step (td )
4187-             assert   (td [0 ].get ("next" )  ==   0 ). all ( )
4200+             assert_allclose_td (td [0 ].get ("next" ),  td [ 0 ],  intersection = True )
41884201            assert  (td [1 ].get ("next" ) !=  0 ).any ()
4189-             assert   (td [2 ].get ("next" )  ==   0 ). all ( )
4202+             assert_allclose_td (td [2 ].get ("next" ),  td [ 2 ],  intersection = True )
41904203            assert  (td [3 ].get ("next" ) !=  0 ).any ()
41914204
41924205
0 commit comments