235
235
class TestEnvBase :
236
236
def test_run_type_checks (self ):
237
237
env = ContinuousActionVecMockEnv ()
238
+ env .adapt_dtype = False
238
239
env ._run_type_checks = False
239
240
check_env_specs (env )
240
241
env ._run_type_checks = True
@@ -4112,17 +4113,21 @@ def test_parallel_partial_steps(
4112
4113
use_buffers = use_buffers ,
4113
4114
device = device ,
4114
4115
)
4115
- td = penv .reset ()
4116
- psteps = torch .zeros (4 , dtype = torch .bool )
4117
- psteps [[1 , 3 ]] = True
4118
- td .set ("_step" , psteps )
4119
-
4120
- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4121
- td = penv .step (td )
4122
- assert (td [0 ].get ("next" ) == 0 ).all ()
4123
- assert (td [1 ].get ("next" ) != 0 ).any ()
4124
- assert (td [2 ].get ("next" ) == 0 ).all ()
4125
- assert (td [3 ].get ("next" ) != 0 ).any ()
4116
+ try :
4117
+ td = penv .reset ()
4118
+ psteps = torch .zeros (4 , dtype = torch .bool )
4119
+ psteps [[1 , 3 ]] = True
4120
+ td .set ("_step" , psteps )
4121
+
4122
+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4123
+ td = penv .step (td )
4124
+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4125
+ assert (td [1 ].get ("next" ) != 0 ).any ()
4126
+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4127
+ assert (td [3 ].get ("next" ) != 0 ).any ()
4128
+ finally :
4129
+ penv .close ()
4130
+ del penv
4126
4131
4127
4132
@pytest .mark .parametrize ("use_buffers" , [False , True ])
4128
4133
def test_parallel_partial_step_and_maybe_reset (
@@ -4135,17 +4140,21 @@ def test_parallel_partial_step_and_maybe_reset(
4135
4140
use_buffers = use_buffers ,
4136
4141
device = device ,
4137
4142
)
4138
- td = penv .reset ()
4139
- psteps = torch .zeros (4 , dtype = torch .bool )
4140
- psteps [[1 , 3 ]] = True
4141
- td .set ("_step" , psteps )
4142
-
4143
- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4144
- td , tdreset = penv .step_and_maybe_reset (td )
4145
- assert (td [0 ].get ("next" ) == 0 ).all ()
4146
- assert (td [1 ].get ("next" ) != 0 ).any ()
4147
- assert (td [2 ].get ("next" ) == 0 ).all ()
4148
- assert (td [3 ].get ("next" ) != 0 ).any ()
4143
+ try :
4144
+ td = penv .reset ()
4145
+ psteps = torch .zeros (4 , dtype = torch .bool , device = td .get ("done" ).device )
4146
+ psteps [[1 , 3 ]] = True
4147
+ td .set ("_step" , psteps )
4148
+
4149
+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4150
+ td , tdreset = penv .step_and_maybe_reset (td )
4151
+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4152
+ assert (td [1 ].get ("next" ) != 0 ).any ()
4153
+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4154
+ assert (td [3 ].get ("next" ) != 0 ).any ()
4155
+ finally :
4156
+ penv .close ()
4157
+ del penv
4149
4158
4150
4159
@pytest .mark .parametrize ("use_buffers" , [False , True ])
4151
4160
def test_serial_partial_steps (self , use_buffers , device , env_device ):
@@ -4156,17 +4165,21 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
4156
4165
use_buffers = use_buffers ,
4157
4166
device = device ,
4158
4167
)
4159
- td = penv .reset ()
4160
- psteps = torch .zeros (4 , dtype = torch .bool )
4161
- psteps [[1 , 3 ]] = True
4162
- td .set ("_step" , psteps )
4163
-
4164
- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4165
- td = penv .step (td )
4166
- assert (td [0 ].get ("next" ) == 0 ).all ()
4167
- assert (td [1 ].get ("next" ) != 0 ).any ()
4168
- assert (td [2 ].get ("next" ) == 0 ).all ()
4169
- assert (td [3 ].get ("next" ) != 0 ).any ()
4168
+ try :
4169
+ td = penv .reset ()
4170
+ psteps = torch .zeros (4 , dtype = torch .bool )
4171
+ psteps [[1 , 3 ]] = True
4172
+ td .set ("_step" , psteps )
4173
+
4174
+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4175
+ td = penv .step (td )
4176
+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4177
+ assert (td [1 ].get ("next" ) != 0 ).any ()
4178
+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4179
+ assert (td [3 ].get ("next" ) != 0 ).any ()
4180
+ finally :
4181
+ penv .close ()
4182
+ del penv
4170
4183
4171
4184
@pytest .mark .parametrize ("use_buffers" , [False , True ])
4172
4185
def test_serial_partial_step_and_maybe_reset (self , use_buffers , device , env_device ):
@@ -4184,9 +4197,9 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
4184
4197
4185
4198
td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4186
4199
td = penv .step (td )
4187
- assert (td [0 ].get ("next" ) == 0 ). all ( )
4200
+ assert_allclose_td (td [0 ].get ("next" ), td [ 0 ], intersection = True )
4188
4201
assert (td [1 ].get ("next" ) != 0 ).any ()
4189
- assert (td [2 ].get ("next" ) == 0 ). all ( )
4202
+ assert_allclose_td (td [2 ].get ("next" ), td [ 2 ], intersection = True )
4190
4203
assert (td [3 ].get ("next" ) != 0 ).any ()
4191
4204
4192
4205
0 commit comments