
[BugFix] Avoid calling reset during env init
ghstack-source-id: 5ab8281c34aacfd7dbbfc0e285d88bcae0aededf
Pull Request resolved: #2770

(cherry picked from commit 09e93c1)
vmoens committed Feb 10, 2025
1 parent 6b0d5b8 commit 28c3c7a
Showing 4 changed files with 108 additions and 48 deletions.
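Most of the test-file churn below follows one migration: the convenience properties env.action_spec and env.reward_spec are replaced by explicit lookups into the composite specs, env.full_action_spec[env.action_key] and env.full_reward_spec[env.reward_key], which remain unambiguous when an env carries nested or multiple action entries. A minimal sketch of the two access styles (illustrative only, assuming a single-agent env with a gym backend installed; not part of the diff):

    from torchrl.envs import GymEnv

    env = GymEnv("CartPole-v1")

    # Old style: convenience property, assumes one trivially-keyed action entry.
    action = env.action_spec.rand()

    # New style used throughout this commit: index the composite spec
    # with the env's action key ("action" for this env).
    action = env.full_action_spec[env.action_key].rand()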
90 changes: 60 additions & 30 deletions test/test_env.py
@@ -73,6 +73,13 @@
 from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator
 from torchrl.modules.tensordict_module import WorldModelWrapper
 
+pytestmark = [
+    pytest.mark.filterwarnings("error"),
+    pytest.mark.filterwarnings(
+        "ignore:Got multiple backends for torchrl.data.replay_buffers.storages"
+    ),
+]
+
 gym_version = None
 if _has_gym:
     try:
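This new module-level marker motivates much of the diff: every warning raised in test_env.py now fails the test, with one targeted exemption for a known benign backend warning. The same pattern in isolation (hypothetical warning message, not the real one):

    import warnings

    import pytest

    pytestmark = [
        pytest.mark.filterwarnings("error"),
        pytest.mark.filterwarnings("ignore:a known benign message"),
    ]

    def test_no_stray_warnings():
        # The exempted warning passes; any other warning now fails the test.
        warnings.warn("a known benign message with detail", UserWarning)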
@@ -232,7 +239,7 @@ def test_run_type_checks(self):
         check_env_specs(env)
         env._run_type_checks = True
         check_env_specs(env)
-        env.output_spec.unlock_()
+        env.output_spec.unlock_(recurse=True)
         # check type check on done
         env.output_spec["full_done_spec", "done"].dtype = torch.int
         with pytest.raises(TypeError, match="expected done.dtype to"):
@@ -292,8 +299,8 @@ def test_single_env_spec(self):
         assert not env.output_spec_unbatched.shape
         assert not env.full_reward_spec_unbatched.shape
 
-        assert env.action_spec_unbatched.shape
-        assert env.reward_spec_unbatched.shape
+        assert env.full_action_spec_unbatched[env.action_key].shape
+        assert env.full_reward_spec_unbatched[env.reward_key].shape
 
         assert env.output_spec.is_in(env.output_spec_unbatched.zeros(env.shape))
         assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))
@@ -307,7 +314,10 @@ def forward(self, values):
                 return values.argmax(-1)
 
         policy = nn.Sequential(
-            nn.Linear(env.observation_spec["observation"].shape[-1], env.action_spec.n),
+            nn.Linear(
+                env.observation_spec["observation"].shape[-1],
+                env.full_action_spec[env.action_key].n,
+            ),
             ArgMaxModule(),
         )
         env.rollout(10, policy)
@@ -507,7 +517,7 @@ def test_auto_cast_to_device(self, break_when_any_done):
         policy = Actor(
             nn.Linear(
                 env.observation_spec["observation"].shape[-1],
-                env.action_spec.shape[-1],
+                env.full_action_spec[env.action_key].shape[-1],
                 device="cuda:0",
             ),
             in_keys=["observation"],
@@ -538,7 +548,7 @@ def test_auto_cast_to_device(self, break_when_any_done):
     def test_env_seed(self, env_name, frame_skip, seed=0):
         env_name = env_name()
         env = GymEnv(env_name, frame_skip=frame_skip)
-        action = env.action_spec.rand()
+        action = env.full_action_spec[env.action_key].rand()
 
         env.set_seed(seed)
         td0a = env.reset()
@@ -624,7 +634,7 @@ def test_env_base_reset_flag(self, batch_size, max_steps=3):
         env = CountingEnv(max_steps=max_steps, batch_size=batch_size)
         env.set_seed(1)
 
-        action = env.action_spec.rand()
+        action = env.full_action_spec[env.action_key].rand()
         action[:] = 1
 
         for i in range(max_steps):
@@ -695,7 +705,7 @@ def test_batch_locked(self, device):
         with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
             env.batch_locked = False
         td = env.reset()
-        td["action"] = env.action_spec.rand()
+        td["action"] = env.full_action_spec[env.action_key].rand()
         td_expanded = td.expand(2).clone()
         _ = env.step(td)
 
@@ -712,7 +722,7 @@ def test_batch_unlocked(self, device):
         with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
             env.batch_locked = False
         td = env.reset()
-        td["action"] = env.action_spec.rand()
+        td["action"] = env.full_action_spec[env.action_key].rand()
         td_expanded = td.expand(2).clone()
         td = env.step(td)
 
@@ -727,7 +737,7 @@ def test_batch_unlocked_with_batch_size(self, device):
             env.batch_locked = False
 
         td = env.reset()
-        td["action"] = env.action_spec.rand()
+        td["action"] = env.full_action_spec[env.action_key].rand()
         td_expanded = td.expand(2, 2).reshape(-1).to_tensordict()
         td = env.step(td)
 
@@ -803,7 +813,7 @@ def test_rollouts_chaining(self, max_steps, batch_size=(4,), epochs=4):
         # CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
         env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
         policy = CountingEnvCountPolicy(
-            action_spec=env.action_spec, action_key=env.action_key
+            action_spec=env.full_action_spec[env.action_key], action_key=env.action_key
         )
 
         input_td = env.reset()
@@ -1010,7 +1020,7 @@ def test_mb_env_batch_lock(self, device, seed=0):
         with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
             mb_env.batch_locked = False
         td = mb_env.reset()
-        td["action"] = mb_env.action_spec.rand()
+        td["action"] = mb_env.full_action_spec[mb_env.action_key].rand()
         td_expanded = td.unsqueeze(-1).expand(10, 2).reshape(-1).to_tensordict()
         mb_env.step(td)
 
@@ -1028,7 +1038,7 @@ def test_mb_env_batch_lock(self, device, seed=0):
         with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
             mb_env.batch_locked = False
         td = mb_env.reset()
-        td["action"] = mb_env.action_spec.rand()
+        td["action"] = mb_env.full_action_spec[mb_env.action_key].rand()
         td_expanded = td.expand(2)
         mb_env.step(td)
         # we should be able to do a step with a tensordict that has been expended
@@ -1242,6 +1252,7 @@ def test_parallel_env(
             N=N,
         )
         td = TensorDict(source={"action": env0.action_spec.rand((N,))}, batch_size=[N])
+        env_parallel.reset()
         td1 = env_parallel.step(td)
         assert not td1.is_shared()
         assert ("next", "done") in td1.keys(True)
@@ -1308,6 +1319,7 @@ def test_parallel_env_with_policy(
         )
 
         td = TensorDict(source={"action": env0.action_spec.rand((N,))}, batch_size=[N])
+        env_parallel.reset()
         td1 = env_parallel.step(td)
         assert not td1.is_shared()
         assert ("next", "done") in td1.keys(True)
@@ -1715,7 +1727,7 @@ def test_parallel_env_reset_flag(
             n_workers, lambda: CountingEnv(max_steps=max_steps, batch_size=batch_size)
         )
         env.set_seed(1)
-        action = env.action_spec.rand()
+        action = env.full_action_spec[env.action_key].rand()
         action[:] = 1
         for i in range(max_steps):
             td = env.step(
@@ -1787,7 +1799,9 @@ def test_parallel_env_nested(
         if not nested_done and not nested_reward and not nested_obs_action:
             assert "data" not in td.keys()
 
-        policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
+        policy = CountingEnvCountPolicy(
+            env.full_action_spec[env.action_key], env.action_key
+        )
         td = env.rollout(rollout_length, policy)
         assert td.batch_size == (*batch_size, rollout_length)
         if nested_done or nested_obs_action:
@@ -2558,6 +2572,7 @@ def main_collector(j, q=None):
             total_frames=N * n_workers * 100,
             storing_device=device,
             device=device,
+            trust_policy=True,
             cat_results=-1,
         )
         single_collectors = [
@@ -2567,6 +2582,7 @@ def main_collector(j, q=None):
                 frames_per_batch=n_workers * 100,
                 total_frames=N * n_workers * 100,
                 storing_device=device,
+                trust_policy=True,
                 device=device,
             )
             for i in range(n_workers)
@@ -2662,18 +2678,24 @@ def test_nested_env(self, envclass):
         else:
             raise NotImplementedError
         reset = env.reset()
-        assert not isinstance(env.reward_spec, Composite)
+        with pytest.warns(
+            DeprecationWarning, match="non-trivial"
+        ) if envclass == "NestedCountingEnv" else contextlib.nullcontext():
+            assert not isinstance(env.reward_spec, Composite)
         for done_key in env.done_keys:
             assert (
                 env.full_done_spec[done_key]
                 == env.output_spec[("full_done_spec", *_unravel_key_to_tuple(done_key))]
             )
-        assert (
-            env.reward_spec
-            == env.output_spec[
-                ("full_reward_spec", *_unravel_key_to_tuple(env.reward_key))
-            ]
-        )
+        with pytest.warns(
+            DeprecationWarning, match="non-trivial"
+        ) if envclass == "NestedCountingEnv" else contextlib.nullcontext():
+            assert (
+                env.reward_spec
+                == env.output_spec[
+                    ("full_reward_spec", *_unravel_key_to_tuple(env.reward_key))
+                ]
+            )
         if envclass == "NestedCountingEnv":
             for done_key in env.done_keys:
                 assert done_key in (("data", "done"), ("data", "terminated"))
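With warnings promoted to errors, a test that still touches the deprecated reward_spec property has to assert the DeprecationWarning where it is expected and do nothing otherwise; the conditional context-manager expression above covers both branches. The idiom in isolation (a hypothetical check, not TorchRL code):

    import contextlib
    import warnings

    import pytest

    def check_value(value, expect_warning):
        # pytest.warns when a warning is expected, a no-op context otherwise.
        ctx = (
            pytest.warns(DeprecationWarning, match="old")
            if expect_warning
            else contextlib.nullcontext()
        )
        with ctx:
            if expect_warning:
                warnings.warn("old API", DeprecationWarning)
            assert value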
@@ -2734,7 +2756,9 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):
             nested_dim,
         )
 
-        policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
+        policy = CountingEnvCountPolicy(
+            env.full_action_spec[env.action_key], env.action_key
+        )
         td = env.rollout(rollout_length, policy)
         assert td.batch_size == (*batch_size, rollout_length)
         assert td["data"].batch_size == (*batch_size, rollout_length, nested_dim)
@@ -2858,7 +2882,7 @@ class TestMultiKeyEnvs:
     @pytest.mark.parametrize("max_steps", [2, 5])
     def test_rollout(self, batch_size, rollout_steps, max_steps, seed):
         env = MultiKeyCountingEnv(batch_size=batch_size, max_steps=max_steps)
-        policy = MultiKeyCountingEnvPolicy(full_action_spec=env.action_spec)
+        policy = MultiKeyCountingEnvPolicy(full_action_spec=env.full_action_spec)
         td = env.rollout(rollout_steps, policy=policy)
         torch.manual_seed(seed)
         check_rollout_consistency_multikey_env(td, max_steps=max_steps)
@@ -2924,11 +2948,17 @@ def test_parallel(
 )
 def test_mocking_envs(envclass):
     env = envclass()
-    env.set_seed(100)
+    with pytest.warns(UserWarning, match="model based") if isinstance(
+        env, DummyModelBasedEnvBase
+    ) else contextlib.nullcontext():
+        env.set_seed(100)
     reset = env.reset()
     _ = env.rand_step(reset)
     r = env.rollout(3)
-    check_env_specs(env, seed=100, return_contiguous=False)
+    with pytest.warns(UserWarning, match="model based") if isinstance(
+        env, DummyModelBasedEnvBase
+    ) else contextlib.nullcontext():
+        check_env_specs(env, seed=100, return_contiguous=False)
 
 
 class TestTerminatedOrTruncated:
Expand Down Expand Up @@ -4019,7 +4049,7 @@ def test_parallel_partial_steps(
psteps[[1, 3]] = True
td.set("_step", psteps)

td.set("action", penv.action_spec.one())
td.set("action", penv.full_action_spec[penv.action_key].one())
td = penv.step(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
@@ -4042,7 +4072,7 @@ def test_parallel_partial_step_and_maybe_reset(
         psteps[[1, 3]] = True
         td.set("_step", psteps)
 
-        td.set("action", penv.action_spec.one())
+        td.set("action", penv.full_action_spec[penv.action_key].one())
         td, tdreset = penv.step_and_maybe_reset(td)
         assert (td[0].get("next") == 0).all()
         assert (td[1].get("next") != 0).any()
@@ -4063,7 +4093,7 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
         psteps[[1, 3]] = True
         td.set("_step", psteps)
 
-        td.set("action", penv.action_spec.one())
+        td.set("action", penv.full_action_spec[penv.action_key].one())
         td = penv.step(td)
         assert (td[0].get("next") == 0).all()
         assert (td[1].get("next") != 0).any()
@@ -4084,7 +4114,7 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device):
         psteps[[1, 3]] = True
         td.set("_step", psteps)
 
-        td.set("action", penv.action_spec.one())
+        td.set("action", penv.full_action_spec[penv.action_key].one())
         td = penv.step(td)
         assert (td[0].get("next") == 0).all()
         assert (td[1].get("next") != 0).any()
47 changes: 34 additions & 13 deletions torchrl/envs/batched_envs.py
@@ -10,6 +10,7 @@
 import gc
 
 import os
+import time
 import weakref
 from collections import OrderedDict
 from copy import deepcopy
@@ -1616,11 +1617,7 @@ def step_and_maybe_reset(
         for i, _data in zip(workers_range, data):
             self.parent_channels[i].send(("step_and_maybe_reset", _data))
 
-        for i in workers_range:
-            event = self._events[i]
-            event.wait(self.BATCHED_PIPE_TIMEOUT)
-            event.clear()
-
+        self._wait_for_workers(workers_range)
         if self._non_tensor_keys:
             non_tensor_tds = []
             for i in workers_range:
@@ -1670,6 +1667,36 @@ def step_and_maybe_reset(
 
         return tensordict, tensordict_
 
+    def _wait_for_workers(self, workers_range):
+        workers_range_consume = set(workers_range)
+        t0 = time.time()
+        while (
+            len(workers_range_consume)
+            and (time.time() - t0) < self.BATCHED_PIPE_TIMEOUT
+        ):
+            for i in workers_range:
+                if i not in workers_range_consume:
+                    continue
+                worker = self._workers[i]
+                if worker.is_alive():
+                    event: mp.Event = self._events[i]
+                    if event.is_set():
+                        workers_range_consume.discard(i)
+                        event.clear()
+                    else:
+                        continue
+                else:
+                    try:
+                        self._shutdown_workers()
+                    finally:
+                        raise RuntimeError(f"Cannot proceed, worker {i} dead.")
+            # event.wait(self.BATCHED_PIPE_TIMEOUT)
+        if len(workers_range_consume):
+            raise RuntimeError(
+                f"Failed to run all workers within the {self.BATCHED_PIPE_TIMEOUT} sec time limit. This "
+                f"threshold can be increased via the BATCHED_PIPE_TIMEOUT env variable."
+            )
+
     def _step_no_buffers(
         self, tensordict: TensorDictBase
     ) -> Tuple[TensorDictBase, TensorDictBase]:
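The new _wait_for_workers replaces the per-worker blocking event.wait(BATCHED_PIPE_TIMEOUT): by polling events and checking worker.is_alive() on each pass, a crashed worker now raises a RuntimeError immediately instead of stalling until the whole timeout elapses. The same liveness-aware wait pattern as a standalone sketch (names are illustrative, not TorchRL API):

    import time

    def wait_for_events(workers, events, timeout):
        # Poll completion events; fail fast if a worker process dies.
        pending = set(range(len(workers)))
        t0 = time.time()
        while pending and (time.time() - t0) < timeout:
            for i in list(pending):
                if not workers[i].is_alive():
                    raise RuntimeError(f"worker {i} died before signalling")
                if events[i].is_set():
                    events[i].clear()
                    pending.discard(i)
        if pending:
            raise RuntimeError(f"workers {sorted(pending)} timed out after {timeout}s")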
@@ -1806,10 +1833,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         for i in workers_range:
             self.parent_channels[i].send(("step", data[i]))
 
-        for i in workers_range:
-            event = self._events[i]
-            event.wait(self.BATCHED_PIPE_TIMEOUT)
-            event.clear()
+        self._wait_for_workers(workers_range)
 
         if self._non_tensor_keys:
             non_tensor_tds = []
@@ -1975,10 +1999,7 @@ def tentative_update(val, other):
         for i, out in outs:
             self.parent_channels[i].send(out)
 
-        for i, _ in outs:
-            event = self._events[i]
-            event.wait(self.BATCHED_PIPE_TIMEOUT)
-            event.clear()
+        self._wait_for_workers(list(zip(*outs))[0])
 
         workers_nontensor = []
         if self._non_tensor_keys:
(The remaining two changed files are not rendered in this view.)
