@@ -269,11 +269,20 @@ def _reset(self, tensordict):
269
269
batch_size = (
270
270
tensordict .batch_size if tensordict is not None else self .batch_size
271
271
)
272
- if tensordict is None or tensordict . is_empty () :
272
+ if tensordict is None or "params" not in tensordict :
273
273
# if no ``tensordict`` is passed, we generate a single set of hyperparameters
274
274
# Otherwise, we assume that the input ``tensordict`` contains all the relevant
275
275
# parameters to get started.
276
276
tensordict = self .gen_params (batch_size = batch_size , device = self .device )
277
+ elif "th" in tensordict and "thdot" in tensordict :
278
+ # we can hard-reset the env too
279
+ return tensordict
280
+ out = self ._reset_random_data (
281
+ tensordict .shape , batch_size , tensordict ["params" ]
282
+ )
283
+ return out
284
+
285
+ def _reset_random_data (self , shape , batch_size , params ):
277
286
278
287
high_th = torch .tensor (self .DEFAULT_X , device = self .device )
279
288
high_thdot = torch .tensor (self .DEFAULT_Y , device = self .device )
@@ -284,20 +293,20 @@ def _reset(self, tensordict):
284
293
# of simulators run simultaneously. In other contexts, the initial
285
294
# random state's shape will depend upon the environment batch-size instead.
286
295
th = (
287
- torch .rand (tensordict . shape , generator = self .rng , device = self .device )
296
+ torch .rand (shape , generator = self .rng , device = self .device )
288
297
* (high_th - low_th )
289
298
+ low_th
290
299
)
291
300
thdot = (
292
- torch .rand (tensordict . shape , generator = self .rng , device = self .device )
301
+ torch .rand (shape , generator = self .rng , device = self .device )
293
302
* (high_thdot - low_thdot )
294
303
+ low_thdot
295
304
)
296
305
out = TensorDict (
297
306
{
298
307
"th" : th ,
299
308
"thdot" : thdot ,
300
- "params" : tensordict [ " params" ] ,
309
+ "params" : params ,
301
310
},
302
311
batch_size = batch_size ,
303
312
)
0 commit comments