
Commit e4076f4

Update
[ghstack-poisoned]
2 parents: 1024d61 + 3365f01

6 files changed: +313 −138 lines

test/test_collector.py

+33 lines
@@ -690,6 +690,39 @@ def make_env():
     del env
 
 
+@pytest.mark.parametrize(
+    "break_when_any_done,break_when_all_done",
+    [[True, False], [False, True], [False, False]],
+)
+@pytest.mark.parametrize("n_envs", [1, 4])
+def test_collector_outplace_policy(n_envs, break_when_any_done, break_when_all_done):
+    def policy_inplace(td):
+        td.set("action", torch.ones(td.shape + (1,)))
+        return td
+
+    def policy_outplace(td):
+        return td.empty().set("action", torch.ones(td.shape + (1,)))
+
+    if n_envs == 1:
+        env = CountingEnv(10)
+    else:
+        env = SerialEnv(
+            n_envs,
+            [functools.partial(CountingEnv, 10 + i) for i in range(n_envs)],
+        )
+    env.reset()
+    c_inplace = SyncDataCollector(
+        env, policy_inplace, frames_per_batch=10, total_frames=100
+    )
+    d_inplace = torch.cat(list(c_inplace), dim=0)
+    env.reset()
+    c_outplace = SyncDataCollector(
+        env, policy_outplace, frames_per_batch=10, total_frames=100
+    )
+    d_outplace = torch.cat(list(c_outplace), dim=0)
+    assert_allclose_td(d_inplace, d_outplace)
+
+
 # Deprecated reset_when_done
 # @pytest.mark.parametrize("num_env", [1, 2])
 # @pytest.mark.parametrize("env_name", ["vec"])
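
Note on the new collector test: it contrasts an in-place policy, which writes the action into the TensorDict it receives, with an out-of-place policy, which returns a fresh TensorDict, and asserts that SyncDataCollector yields identical batches either way. A minimal sketch of that distinction, using only tensordict (the CountingEnv/SerialEnv mocks and the collector are left out, so this is illustrative rather than the test itself):

import torch
from tensordict import TensorDict

def policy_inplace(td):
    # mutates the TensorDict it was given and returns the same object
    td.set("action", torch.ones(td.shape + (1,)))
    return td

def policy_outplace(td):
    # builds a new, initially empty TensorDict carrying only the action
    return td.empty().set("action", torch.ones(td.shape + (1,)))

td = TensorDict({}, batch_size=[4])
assert policy_inplace(td) is td                      # same object, modified in place
assert policy_outplace(td) is not td                 # original is left untouched
assert policy_outplace(td)["action"].shape == (4, 1)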

test/test_env.py

+184 −129 lines
@@ -250,109 +250,200 @@ def test_env_seed(env_name, frame_skip, seed=0):
     env.close()
 
 
-@pytest.mark.skipif(not _has_gym, reason="no gym")
-@pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, PONG_VERSIONED])
-@pytest.mark.parametrize("frame_skip", [1, 4])
-def test_rollout(env_name, frame_skip, seed=0):
-    if env_name is PONG_VERSIONED and version.parse(
-        gym_backend().__version__
-    ) < version.parse("0.19"):
-        # Then 100 steps in pong are not sufficient to detect a difference
-        pytest.skip("can't detect difference in gym rollout with this gym version.")
+class TestRollout:
+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, PONG_VERSIONED])
+    @pytest.mark.parametrize("frame_skip", [1, 4])
+    def test_rollout(self, env_name, frame_skip, seed=0):
+        if env_name is PONG_VERSIONED and version.parse(
+            gym_backend().__version__
+        ) < version.parse("0.19"):
+            # Then 100 steps in pong are not sufficient to detect a difference
+            pytest.skip("can't detect difference in gym rollout with this gym version.")
 
-    env_name = env_name()
-    env = GymEnv(env_name, frame_skip=frame_skip)
+        env_name = env_name()
+        env = GymEnv(env_name, frame_skip=frame_skip)
 
-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    env.set_seed(seed)
-    env.reset()
-    rollout1 = env.rollout(max_steps=100)
-    assert rollout1.names[-1] == "time"
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        env.set_seed(seed)
+        env.reset()
+        rollout1 = env.rollout(max_steps=100)
+        assert rollout1.names[-1] == "time"
 
-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    env.set_seed(seed)
-    env.reset()
-    rollout2 = env.rollout(max_steps=100)
-    assert rollout2.names[-1] == "time"
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        env.set_seed(seed)
+        env.reset()
+        rollout2 = env.rollout(max_steps=100)
+        assert rollout2.names[-1] == "time"
 
-    assert_allclose_td(rollout1, rollout2)
+        assert_allclose_td(rollout1, rollout2)
 
-    torch.manual_seed(seed)
-    env.set_seed(seed + 10)
-    env.reset()
-    rollout3 = env.rollout(max_steps=100)
-    with pytest.raises(AssertionError):
-        assert_allclose_td(rollout1, rollout3)
-    env.close()
+        torch.manual_seed(seed)
+        env.set_seed(seed + 10)
+        env.reset()
+        rollout3 = env.rollout(max_steps=100)
+        with pytest.raises(AssertionError):
+            assert_allclose_td(rollout1, rollout3)
+        env.close()
 
+    def test_rollout_set_truncated(self):
+        env = ContinuousActionVecMockEnv()
+        with pytest.raises(RuntimeError, match="set_truncated was set to True"):
+            env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
+        env.add_truncated_keys()
+        r = env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
+        assert r.shape == torch.Size([10])
+        assert r[..., -1]["next", "truncated"].all()
+        assert r[..., -1]["next", "done"].all()
+
+    @pytest.mark.parametrize("max_steps", [1, 5])
+    def test_rollouts_chaining(self, max_steps, batch_size=(4,), epochs=4):
+        # CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
+        env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
+        policy = CountingEnvCountPolicy(
+            action_spec=env.action_spec, action_key=env.action_key
+        )
+
+        input_td = env.reset()
+        for _ in range(epochs):
+            rollout_td = env.rollout(
+                max_steps=max_steps,
+                policy=policy,
+                auto_reset=False,
+                break_when_any_done=False,
+                tensordict=input_td,
+            )
+            assert (env.count == max_steps).all()
+            input_td = step_mdp(
+                rollout_td[..., -1],
+                keep_other=True,
+                exclude_action=False,
+                exclude_reward=True,
+                reward_keys=env.reward_keys,
+                action_keys=env.action_keys,
+                done_keys=env.done_keys,
+            )
 
-def test_rollout_set_truncated():
-    env = ContinuousActionVecMockEnv()
-    with pytest.raises(RuntimeError, match="set_truncated was set to True"):
-        env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
-    env.add_truncated_keys()
-    r = env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
-    assert r.shape == torch.Size([10])
-    assert r[..., -1]["next", "truncated"].all()
-    assert r[..., -1]["next", "done"].all()
-
-
-@pytest.mark.parametrize("max_steps", [1, 5])
-def test_rollouts_chaining(max_steps, batch_size=(4,), epochs=4):
-    # CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
-    env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
-    policy = CountingEnvCountPolicy(
-        action_spec=env.action_spec, action_key=env.action_key
-    )
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_rollout_predictability(self, device):
+        env = MockSerialEnv(device=device)
+        env.set_seed(100)
+        first = 100 % 17
+        policy = Actor(torch.nn.Linear(1, 1, bias=False)).to(device)
+        for p in policy.parameters():
+            p.data.fill_(1.0)
+        td_out = env.rollout(policy=policy, max_steps=200)
+        assert (
+            torch.arange(first, first + 100, device=device)
+            == td_out.get("observation").squeeze()
+        ).all()
+        assert (
+            torch.arange(first + 1, first + 101, device=device)
+            == td_out.get(("next", "observation")).squeeze()
+        ).all()
+        assert (
+            torch.arange(first + 1, first + 101, device=device)
+            == td_out.get(("next", "reward")).squeeze()
+        ).all()
+        assert (
+            torch.arange(first, first + 100, device=device)
+            == td_out.get("action").squeeze()
+        ).all()
 
-    input_td = env.reset()
-    for _ in range(epochs):
-        rollout_td = env.rollout(
-            max_steps=max_steps,
-            policy=policy,
-            auto_reset=False,
-            break_when_any_done=False,
-            tensordict=input_td,
-        )
-        assert (env.count == max_steps).all()
-        input_td = step_mdp(
-            rollout_td[..., -1],
-            keep_other=True,
-            exclude_action=False,
-            exclude_reward=True,
-            reward_keys=env.reward_keys,
-            action_keys=env.action_keys,
-            done_keys=env.done_keys,
-        )
+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED])
+    @pytest.mark.parametrize("frame_skip", [1])
+    @pytest.mark.parametrize("truncated_key", ["truncated", "done"])
+    @pytest.mark.parametrize("parallel", [False, True])
+    def test_rollout_reset(
+        self,
+        env_name,
+        frame_skip,
+        parallel,
+        truncated_key,
+        maybe_fork_ParallelEnv,
+        seed=0,
+    ):
+        env_name = env_name()
+        envs = []
+        for horizon in [20, 30, 40]:
+            envs.append(
+                lambda horizon=horizon: TransformedEnv(
+                    GymEnv(env_name, frame_skip=frame_skip),
+                    StepCounter(horizon, truncated_key=truncated_key),
+                )
+            )
+        if parallel:
+            env = maybe_fork_ParallelEnv(3, envs)
+        else:
+            env = SerialEnv(3, envs)
+        env.set_seed(100)
+        out = env.rollout(100, break_when_any_done=False)
+        assert out.names[-1] == "time"
+        assert out.shape == torch.Size([3, 100])
+        assert (
+            out[..., -1]["step_count"].squeeze().cpu() == torch.tensor([19, 9, 19])
+        ).all()
+        assert (
+            out[..., -1]["next", "step_count"].squeeze().cpu()
+            == torch.tensor([20, 10, 20])
+        ).all()
+        assert (
+            out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])
+        ).all()
 
+    @pytest.mark.parametrize(
+        "break_when_any_done,break_when_all_done",
+        [[True, False], [False, True], [False, False]],
+    )
+    @pytest.mark.parametrize("n_envs,serial", [[1, None], [4, True], [4, False]])
+    def test_rollout_outplace_policy(
+        self, n_envs, serial, break_when_any_done, break_when_all_done
+    ):
+        def policy_inplace(td):
+            td.set("action", torch.ones(td.shape + (1,)))
+            return td
 
-@pytest.mark.parametrize("device", get_default_devices())
-def test_rollout_predictability(device):
-    env = MockSerialEnv(device=device)
-    env.set_seed(100)
-    first = 100 % 17
-    policy = Actor(torch.nn.Linear(1, 1, bias=False)).to(device)
-    for p in policy.parameters():
-        p.data.fill_(1.0)
-    td_out = env.rollout(policy=policy, max_steps=200)
-    assert (
-        torch.arange(first, first + 100, device=device)
-        == td_out.get("observation").squeeze()
-    ).all()
-    assert (
-        torch.arange(first + 1, first + 101, device=device)
-        == td_out.get(("next", "observation")).squeeze()
-    ).all()
-    assert (
-        torch.arange(first + 1, first + 101, device=device)
-        == td_out.get(("next", "reward")).squeeze()
-    ).all()
-    assert (
-        torch.arange(first, first + 100, device=device)
-        == td_out.get("action").squeeze()
-    ).all()
+        def policy_outplace(td):
+            return td.empty().set("action", torch.ones(td.shape + (1,)))
+
+        if n_envs == 1:
+            env = CountingEnv(10)
+        elif serial:
+            env = SerialEnv(
+                n_envs,
+                [partial(CountingEnv, 10 + i) for i in range(n_envs)],
+            )
+        else:
+            env = ParallelEnv(
+                n_envs,
+                [partial(CountingEnv, 10 + i) for i in range(n_envs)],
+                mp_start_method=mp_ctx,
+            )
+        r_inplace = env.rollout(
+            40,
+            policy_inplace,
+            break_when_all_done=break_when_all_done,
+            break_when_any_done=break_when_any_done,
+        )
+        r_outplace = env.rollout(
+            40,
+            policy_outplace,
+            break_when_all_done=break_when_all_done,
+            break_when_any_done=break_when_any_done,
+        )
+        if break_when_any_done:
+            assert r_outplace.shape[-1:] == (11,)
+        elif break_when_all_done:
+            if n_envs > 1:
+                assert r_outplace.shape[-1:] == (14,)
+            else:
+                assert r_outplace.shape[-1:] == (11,)
+        else:
+            assert r_outplace.shape[-1:] == (40,)
+        assert_allclose_td(r_inplace, r_outplace)
 
 
 # Check that the "terminated" key is filled in automatically if only the "done"
@@ -411,42 +502,6 @@ def _step(
     assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]]))
 
 
-@pytest.mark.skipif(not _has_gym, reason="no gym")
-@pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED])
-@pytest.mark.parametrize("frame_skip", [1])
-@pytest.mark.parametrize("truncated_key", ["truncated", "done"])
-@pytest.mark.parametrize("parallel", [False, True])
-def test_rollout_reset(
-    env_name, frame_skip, parallel, truncated_key, maybe_fork_ParallelEnv, seed=0
-):
-    env_name = env_name()
-    envs = []
-    for horizon in [20, 30, 40]:
-        envs.append(
-            lambda horizon=horizon: TransformedEnv(
-                GymEnv(env_name, frame_skip=frame_skip),
-                StepCounter(horizon, truncated_key=truncated_key),
-            )
-        )
-    if parallel:
-        env = maybe_fork_ParallelEnv(3, envs)
-    else:
-        env = SerialEnv(3, envs)
-    env.set_seed(100)
-    out = env.rollout(100, break_when_any_done=False)
-    assert out.names[-1] == "time"
-    assert out.shape == torch.Size([3, 100])
-    assert (
-        out[..., -1]["step_count"].squeeze().cpu() == torch.tensor([19, 9, 19])
-    ).all()
-    assert (
-        out[..., -1]["next", "step_count"].squeeze().cpu() == torch.tensor([20, 10, 20])
-    ).all()
-    assert (
-        out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])
-    ).all()
-
-
 class TestModelBasedEnvBase:
     @staticmethod
    def world_model():
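
For readers unfamiliar with the rollout API exercised above: as the expected shapes in test_rollout_outplace_policy suggest, break_when_any_done stops at the first sub-environment to finish (11 steps for CountingEnv(10)), break_when_all_done stops once the slowest one has finished (14 steps for CountingEnv(13)), and with both flags off the rollout runs the full 40 requested steps. A rough sketch of the call pattern, assuming torchrl with a gym backend installed; "Pendulum-v1" merely stands in for the CountingEnv mock, so the early-break flags will not actually trigger here:

import functools
import torch
from torchrl.envs import GymEnv, SerialEnv

def policy_outplace(td):
    # out-of-place policy: return a new TensorDict containing only the action
    return td.empty().set("action", torch.zeros(td.shape + (1,)))

# batched env built from per-worker factories, mirroring the tests above
env = SerialEnv(2, [functools.partial(GymEnv, "Pendulum-v1") for _ in range(2)])
r = env.rollout(
    40,
    policy_outplace,
    break_when_any_done=False,  # keep stepping past any single sub-env's "done"
)
print(r.shape)  # torch.Size([2, 40]) when nothing cuts the rollout short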
