@@ -250,109 +250,200 @@ def test_env_seed(env_name, frame_skip, seed=0):
    env.close()


-@pytest.mark.skipif(not _has_gym, reason="no gym")
-@pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, PONG_VERSIONED])
-@pytest.mark.parametrize("frame_skip", [1, 4])
-def test_rollout(env_name, frame_skip, seed=0):
-    if env_name is PONG_VERSIONED and version.parse(
-        gym_backend().__version__
-    ) < version.parse("0.19"):
-        # Then 100 steps in pong are not sufficient to detect a difference
-        pytest.skip("can't detect difference in gym rollout with this gym version.")
+class TestRollout:
+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, PONG_VERSIONED])
+    @pytest.mark.parametrize("frame_skip", [1, 4])
+    def test_rollout(self, env_name, frame_skip, seed=0):
+        if env_name is PONG_VERSIONED and version.parse(
+            gym_backend().__version__
+        ) < version.parse("0.19"):
+            # Then 100 steps in pong are not sufficient to detect a difference
+            pytest.skip("can't detect difference in gym rollout with this gym version.")

-    env_name = env_name()
-    env = GymEnv(env_name, frame_skip=frame_skip)
+        env_name = env_name()
+        env = GymEnv(env_name, frame_skip=frame_skip)

-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    env.set_seed(seed)
-    env.reset()
-    rollout1 = env.rollout(max_steps=100)
-    assert rollout1.names[-1] == "time"
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        env.set_seed(seed)
+        env.reset()
+        rollout1 = env.rollout(max_steps=100)
+        assert rollout1.names[-1] == "time"

-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    env.set_seed(seed)
-    env.reset()
-    rollout2 = env.rollout(max_steps=100)
-    assert rollout2.names[-1] == "time"
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        env.set_seed(seed)
+        env.reset()
+        rollout2 = env.rollout(max_steps=100)
+        assert rollout2.names[-1] == "time"

-    assert_allclose_td(rollout1, rollout2)
+        assert_allclose_td(rollout1, rollout2)

-    torch.manual_seed(seed)
-    env.set_seed(seed + 10)
-    env.reset()
-    rollout3 = env.rollout(max_steps=100)
-    with pytest.raises(AssertionError):
-        assert_allclose_td(rollout1, rollout3)
-    env.close()
+        torch.manual_seed(seed)
+        env.set_seed(seed + 10)
+        env.reset()
+        rollout3 = env.rollout(max_steps=100)
+        with pytest.raises(AssertionError):
+            assert_allclose_td(rollout1, rollout3)
+        env.close()

+    def test_rollout_set_truncated(self):
+        env = ContinuousActionVecMockEnv()
+        with pytest.raises(RuntimeError, match="set_truncated was set to True"):
+            env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
+        env.add_truncated_keys()
+        r = env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
+        assert r.shape == torch.Size([10])
+        assert r[..., -1]["next", "truncated"].all()
+        assert r[..., -1]["next", "done"].all()
+
+    @pytest.mark.parametrize("max_steps", [1, 5])
+    def test_rollouts_chaining(self, max_steps, batch_size=(4,), epochs=4):
+        # CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
+        env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
+        policy = CountingEnvCountPolicy(
+            action_spec=env.action_spec, action_key=env.action_key
+        )
+
+        input_td = env.reset()
+        for _ in range(epochs):
+            rollout_td = env.rollout(
+                max_steps=max_steps,
+                policy=policy,
+                auto_reset=False,
+                break_when_any_done=False,
+                tensordict=input_td,
+            )
+            assert (env.count == max_steps).all()
+            input_td = step_mdp(
+                rollout_td[..., -1],
+                keep_other=True,
+                exclude_action=False,
+                exclude_reward=True,
+                reward_keys=env.reward_keys,
+                action_keys=env.action_keys,
+                done_keys=env.done_keys,
+            )

-def test_rollout_set_truncated():
-    env = ContinuousActionVecMockEnv()
-    with pytest.raises(RuntimeError, match="set_truncated was set to True"):
-        env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
-    env.add_truncated_keys()
-    r = env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
-    assert r.shape == torch.Size([10])
-    assert r[..., -1]["next", "truncated"].all()
-    assert r[..., -1]["next", "done"].all()
-
-
-@pytest.mark.parametrize("max_steps", [1, 5])
-def test_rollouts_chaining(max_steps, batch_size=(4,), epochs=4):
-    # CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
-    env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
-    policy = CountingEnvCountPolicy(
-        action_spec=env.action_spec, action_key=env.action_key
-    )
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_rollout_predictability(self, device):
+        env = MockSerialEnv(device=device)
+        env.set_seed(100)
+        first = 100 % 17
+        policy = Actor(torch.nn.Linear(1, 1, bias=False)).to(device)
+        for p in policy.parameters():
+            p.data.fill_(1.0)
+        td_out = env.rollout(policy=policy, max_steps=200)
+        assert (
+            torch.arange(first, first + 100, device=device)
+            == td_out.get("observation").squeeze()
+        ).all()
+        assert (
+            torch.arange(first + 1, first + 101, device=device)
+            == td_out.get(("next", "observation")).squeeze()
+        ).all()
+        assert (
+            torch.arange(first + 1, first + 101, device=device)
+            == td_out.get(("next", "reward")).squeeze()
+        ).all()
+        assert (
+            torch.arange(first, first + 100, device=device)
+            == td_out.get("action").squeeze()
+        ).all()

-    input_td = env.reset()
-    for _ in range(epochs):
-        rollout_td = env.rollout(
-            max_steps=max_steps,
-            policy=policy,
-            auto_reset=False,
-            break_when_any_done=False,
-            tensordict=input_td,
-        )
-        assert (env.count == max_steps).all()
-        input_td = step_mdp(
-            rollout_td[..., -1],
-            keep_other=True,
-            exclude_action=False,
-            exclude_reward=True,
-            reward_keys=env.reward_keys,
-            action_keys=env.action_keys,
-            done_keys=env.done_keys,
-        )
+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED])
+    @pytest.mark.parametrize("frame_skip", [1])
+    @pytest.mark.parametrize("truncated_key", ["truncated", "done"])
+    @pytest.mark.parametrize("parallel", [False, True])
+    def test_rollout_reset(
+        self,
+        env_name,
+        frame_skip,
+        parallel,
+        truncated_key,
+        maybe_fork_ParallelEnv,
+        seed=0,
+    ):
+        env_name = env_name()
+        envs = []
+        for horizon in [20, 30, 40]:
+            envs.append(
+                lambda horizon=horizon: TransformedEnv(
+                    GymEnv(env_name, frame_skip=frame_skip),
+                    StepCounter(horizon, truncated_key=truncated_key),
+                )
+            )
+        if parallel:
+            env = maybe_fork_ParallelEnv(3, envs)
+        else:
+            env = SerialEnv(3, envs)
+        env.set_seed(100)
+        out = env.rollout(100, break_when_any_done=False)
+        assert out.names[-1] == "time"
+        assert out.shape == torch.Size([3, 100])
+        assert (
+            out[..., -1]["step_count"].squeeze().cpu() == torch.tensor([19, 9, 19])
+        ).all()
+        assert (
+            out[..., -1]["next", "step_count"].squeeze().cpu()
+            == torch.tensor([20, 10, 20])
+        ).all()
+        assert (
+            out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])
+        ).all()

+    @pytest.mark.parametrize(
+        "break_when_any_done,break_when_all_done",
+        [[True, False], [False, True], [False, False]],
+    )
+    @pytest.mark.parametrize("n_envs,serial", [[1, None], [4, True], [4, False]])
+    def test_rollout_outplace_policy(
+        self, n_envs, serial, break_when_any_done, break_when_all_done
+    ):
+        def policy_inplace(td):
+            td.set("action", torch.ones(td.shape + (1,)))
+            return td

-@pytest.mark.parametrize("device", get_default_devices())
-def test_rollout_predictability(device):
-    env = MockSerialEnv(device=device)
-    env.set_seed(100)
-    first = 100 % 17
-    policy = Actor(torch.nn.Linear(1, 1, bias=False)).to(device)
-    for p in policy.parameters():
-        p.data.fill_(1.0)
-    td_out = env.rollout(policy=policy, max_steps=200)
-    assert (
-        torch.arange(first, first + 100, device=device)
-        == td_out.get("observation").squeeze()
-    ).all()
-    assert (
-        torch.arange(first + 1, first + 101, device=device)
-        == td_out.get(("next", "observation")).squeeze()
-    ).all()
-    assert (
-        torch.arange(first + 1, first + 101, device=device)
-        == td_out.get(("next", "reward")).squeeze()
-    ).all()
-    assert (
-        torch.arange(first, first + 100, device=device)
-        == td_out.get("action").squeeze()
-    ).all()
+        def policy_outplace(td):
+            return td.empty().set("action", torch.ones(td.shape + (1,)))
+
+        if n_envs == 1:
+            env = CountingEnv(10)
+        elif serial:
+            env = SerialEnv(
+                n_envs,
+                [partial(CountingEnv, 10 + i) for i in range(n_envs)],
+            )
+        else:
+            env = ParallelEnv(
+                n_envs,
+                [partial(CountingEnv, 10 + i) for i in range(n_envs)],
+                mp_start_method=mp_ctx,
+            )
+        r_inplace = env.rollout(
+            40,
+            policy_inplace,
+            break_when_all_done=break_when_all_done,
+            break_when_any_done=break_when_any_done,
+        )
+        r_outplace = env.rollout(
+            40,
+            policy_outplace,
+            break_when_all_done=break_when_all_done,
+            break_when_any_done=break_when_any_done,
+        )
+        if break_when_any_done:
+            assert r_outplace.shape[-1:] == (11,)
+        elif break_when_all_done:
+            if n_envs > 1:
+                assert r_outplace.shape[-1:] == (14,)
+            else:
+                assert r_outplace.shape[-1:] == (11,)
+        else:
+            assert r_outplace.shape[-1:] == (40,)
+        assert_allclose_td(r_inplace, r_outplace)


# Check that the "terminated" key is filled in automatically if only the "done"
@@ -411,42 +502,6 @@ def _step(
    assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]]))


-@pytest.mark.skipif(not _has_gym, reason="no gym")
-@pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED])
-@pytest.mark.parametrize("frame_skip", [1])
-@pytest.mark.parametrize("truncated_key", ["truncated", "done"])
-@pytest.mark.parametrize("parallel", [False, True])
-def test_rollout_reset(
-    env_name, frame_skip, parallel, truncated_key, maybe_fork_ParallelEnv, seed=0
-):
-    env_name = env_name()
-    envs = []
-    for horizon in [20, 30, 40]:
-        envs.append(
-            lambda horizon=horizon: TransformedEnv(
-                GymEnv(env_name, frame_skip=frame_skip),
-                StepCounter(horizon, truncated_key=truncated_key),
-            )
-        )
-    if parallel:
-        env = maybe_fork_ParallelEnv(3, envs)
-    else:
-        env = SerialEnv(3, envs)
-    env.set_seed(100)
-    out = env.rollout(100, break_when_any_done=False)
-    assert out.names[-1] == "time"
-    assert out.shape == torch.Size([3, 100])
-    assert (
-        out[..., -1]["step_count"].squeeze().cpu() == torch.tensor([19, 9, 19])
-    ).all()
-    assert (
-        out[..., -1]["next", "step_count"].squeeze().cpu() == torch.tensor([20, 10, 20])
-    ).all()
-    assert (
-        out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])
-    ).all()
-
-
class TestModelBasedEnvBase:
    @staticmethod
    def world_model():