Skip to content

Commit f6084b6

Browse files
committed
[BugFix] Fix collector timeouts
ghstack-source-id: cb71d95143beb22db1fe1752e72f70c19f43be79 Pull Request resolved: #2774
1 parent 3da2750 commit f6084b6

File tree

3 files changed

+1801
-1642
lines changed

3 files changed

+1801
-1642
lines changed

test/mocking_classes.py

+26
Original file line numberDiff line numberDiff line change
@@ -2242,3 +2242,29 @@ def _set_seed(self, seed: Optional[int]):
22422242
random.seed(seed)
22432243
torch.manual_seed(0)
22442244
return seed
2245+
2246+
2247+
class EnvThatErrorsAfter10Iters(EnvBase):
2248+
def __init__(self):
2249+
self.action_spec = Composite(action=Unbounded((1,)))
2250+
self.reward_spec = Composite(reward=Unbounded((1,)))
2251+
self.done_spec = Composite(done=Unbounded((1,)))
2252+
self.observation_spec = Composite(observation=Unbounded((1,)))
2253+
self.counter = 0
2254+
super().__init__()
2255+
2256+
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
2257+
return self.full_observation_spec.zero().update(self.full_done_spec.zero())
2258+
2259+
def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
2260+
if self.counter >= 10:
2261+
raise RuntimeError("max steps!")
2262+
self.counter += 1
2263+
return (
2264+
self.full_observation_spec.zero()
2265+
.update(self.full_done_spec.zero())
2266+
.update(self.full_reward_spec.zero())
2267+
)
2268+
2269+
def _set_seed(self, seed: Optional[int]):
2270+
...

0 commit comments

Comments
 (0)