@@ -4112,17 +4112,21 @@ def test_parallel_partial_steps(
4112
4112
use_buffers = use_buffers ,
4113
4113
device = device ,
4114
4114
)
4115
- td = penv .reset ()
4116
- psteps = torch .zeros (4 , dtype = torch .bool )
4117
- psteps [[1 , 3 ]] = True
4118
- td .set ("_step" , psteps )
4119
-
4120
- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4121
- td = penv .step (td )
4122
- assert (td [0 ].get ("next" ) == 0 ).all ()
4123
- assert (td [1 ].get ("next" ) != 0 ).any ()
4124
- assert (td [2 ].get ("next" ) == 0 ).all ()
4125
- assert (td [3 ].get ("next" ) != 0 ).any ()
4115
+ try :
4116
+ td = penv .reset ()
4117
+ psteps = torch .zeros (4 , dtype = torch .bool )
4118
+ psteps [[1 , 3 ]] = True
4119
+ td .set ("_step" , psteps )
4120
+
4121
+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4122
+ td = penv .step (td )
4123
+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4124
+ assert (td [1 ].get ("next" ) != 0 ).any ()
4125
+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4126
+ assert (td [3 ].get ("next" ) != 0 ).any ()
4127
+ finally :
4128
+ penv .close ()
4129
+ del penv
4126
4130
4127
4131
@pytest .mark .parametrize ("use_buffers" , [False , True ])
4128
4132
def test_parallel_partial_step_and_maybe_reset (
@@ -4135,17 +4139,21 @@ def test_parallel_partial_step_and_maybe_reset(
4135
4139
use_buffers = use_buffers ,
4136
4140
device = device ,
4137
4141
)
4138
- td = penv .reset ()
4139
- psteps = torch .zeros (4 , dtype = torch .bool )
4140
- psteps [[1 , 3 ]] = True
4141
- td .set ("_step" , psteps )
4142
-
4143
- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4144
- td , tdreset = penv .step_and_maybe_reset (td )
4145
- assert (td [0 ].get ("next" ) == 0 ).all ()
4146
- assert (td [1 ].get ("next" ) != 0 ).any ()
4147
- assert (td [2 ].get ("next" ) == 0 ).all ()
4148
- assert (td [3 ].get ("next" ) != 0 ).any ()
4142
+ try :
4143
+ td = penv .reset ()
4144
+ psteps = torch .zeros (4 , dtype = torch .bool )
4145
+ psteps [[1 , 3 ]] = True
4146
+ td .set ("_step" , psteps )
4147
+
4148
+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4149
+ td , tdreset = penv .step_and_maybe_reset (td )
4150
+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4151
+ assert (td [1 ].get ("next" ) != 0 ).any ()
4152
+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4153
+ assert (td [3 ].get ("next" ) != 0 ).any ()
4154
+ finally :
4155
+ penv .close ()
4156
+ del penv
4149
4157
4150
4158
@pytest .mark .parametrize ("use_buffers" , [False , True ])
4151
4159
def test_serial_partial_steps (self , use_buffers , device , env_device ):
@@ -4156,17 +4164,21 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
4156
4164
use_buffers = use_buffers ,
4157
4165
device = device ,
4158
4166
)
4159
- td = penv .reset ()
4160
- psteps = torch .zeros (4 , dtype = torch .bool )
4161
- psteps [[1 , 3 ]] = True
4162
- td .set ("_step" , psteps )
4163
-
4164
- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4165
- td = penv .step (td )
4166
- assert (td [0 ].get ("next" ) == 0 ).all ()
4167
- assert (td [1 ].get ("next" ) != 0 ).any ()
4168
- assert (td [2 ].get ("next" ) == 0 ).all ()
4169
- assert (td [3 ].get ("next" ) != 0 ).any ()
4167
+ try :
4168
+ td = penv .reset ()
4169
+ psteps = torch .zeros (4 , dtype = torch .bool )
4170
+ psteps [[1 , 3 ]] = True
4171
+ td .set ("_step" , psteps )
4172
+
4173
+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4174
+ td = penv .step (td )
4175
+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4176
+ assert (td [1 ].get ("next" ) != 0 ).any ()
4177
+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4178
+ assert (td [3 ].get ("next" ) != 0 ).any ()
4179
+ finally :
4180
+ penv .close ()
4181
+ del penv
4170
4182
4171
4183
@pytest .mark .parametrize ("use_buffers" , [False , True ])
4172
4184
def test_serial_partial_step_and_maybe_reset (self , use_buffers , device , env_device ):
@@ -4184,9 +4196,9 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
4184
4196
4185
4197
td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4186
4198
td = penv .step (td )
4187
- assert (td [0 ].get ("next" ) == 0 ). all ( )
4199
+ assert_allclose_td (td [0 ].get ("next" ), td [ 0 ], intersection = True )
4188
4200
assert (td [1 ].get ("next" ) != 0 ).any ()
4189
- assert (td [2 ].get ("next" ) == 0 ). all ( )
4201
+ assert_allclose_td (td [2 ].get ("next" ), td [ 2 ], intersection = True )
4190
4202
assert (td [3 ].get ("next" ) != 0 ).any ()
4191
4203
4192
4204
0 commit comments