Skip to content

Commit 69453a6

Browse files
author
Vincent Moens
authored
[BugFix] Fix flaky gym penv test (#1853)
1 parent 2754200 commit 69453a6

File tree

26 files changed

+78
-73
lines changed

26 files changed

+78
-73
lines changed

test/_utils_internal.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,9 @@ def rollout_consistency_assertion(
330330
):
331331
"""Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise."""
332332

333-
done = rollout[:, :-1]["next", done_key].squeeze(-1)
333+
done = rollout[..., :-1]["next", done_key].squeeze(-1)
334334
# data resulting from step, when it's not done
335-
r_not_done = rollout[:, :-1]["next"][~done]
335+
r_not_done = rollout[..., :-1]["next"][~done]
336336
# data resulting from step, when it's not done, after step_mdp
337337
r_not_done_tp1 = rollout[:, 1:][~done]
338338
torch.testing.assert_close(
@@ -343,17 +343,15 @@ def rollout_consistency_assertion(
343343

344344
if done_strict and not done.any():
345345
raise RuntimeError("No done detected, test could not complete.")
346-
347-
# data resulting from step, when it's done
348-
r_done = rollout[:, :-1]["next"][done]
349-
# data resulting from step, when it's done, after step_mdp and reset
350-
r_done_tp1 = rollout[:, 1:][done]
351-
assert (
352-
(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1
353-
).all(), (
354-
f"Entries in next tensordict do not match entries in root "
355-
f"tensordict after reset : {(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) < 1e-1}"
356-
)
346+
if done.any():
347+
# data resulting from step, when it's done
348+
r_done = rollout[..., :-1]["next"][done]
349+
# data resulting from step, when it's done, after step_mdp and reset
350+
r_done_tp1 = rollout[..., 1:][done]
351+
# check that at least one obs after reset does not match the version before reset
352+
assert not torch.isclose(
353+
r_done[observation_key], r_done_tp1[observation_key]
354+
).all()
357355

358356

359357
def rand_reset(env):

torchrl/data/datasets/minari_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,8 @@ def _proc_spec(spec):
412412
)
413413
return BoundedTensorSpec(
414414
shape=spec["shape"],
415-
low=torch.tensor(spec["low"]),
416-
high=torch.tensor(spec["high"]),
415+
low=torch.as_tensor(spec["low"]),
416+
high=torch.as_tensor(spec["high"]),
417417
dtype=_DTYPE_DIR[spec["dtype"]],
418418
)
419419
elif spec["type"] == "Discrete":

torchrl/data/datasets/openx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def _slice_data(data: TensorDict, slice_len, pad_value):
684684
truncated,
685685
dim=data.ndim - 1,
686686
value=True,
687-
index=torch.tensor(-1, device=truncated.device),
687+
index=torch.as_tensor(-1, device=truncated.device),
688688
)
689689
done = data.get(("next", "done"))
690690
data.set(("next", "truncated"), truncated)

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,7 @@ def add(self, data: TensorDictBase) -> int:
867867
device=data.device,
868868
)
869869
if data.batch_size:
870-
data_add["_rb_batch_size"] = torch.tensor(data.batch_size)
870+
data_add["_rb_batch_size"] = torch.as_tensor(data.batch_size)
871871

872872
else:
873873
data_add = data
@@ -1441,7 +1441,7 @@ def __getitem__(
14411441
if isinstance(index, slice) and index == slice(None):
14421442
return self
14431443
if isinstance(index, (list, range, np.ndarray)):
1444-
index = torch.tensor(index)
1444+
index = torch.as_tensor(index)
14451445
if isinstance(index, torch.Tensor):
14461446
if index.ndim > 1:
14471447
raise RuntimeError(

torchrl/data/replay_buffers/samplers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,10 +461,10 @@ def dumps(self, path):
461461
filename=path / "mintree.memmap",
462462
)
463463
mm_st.copy_(
464-
torch.tensor([self._sum_tree[i] for i in range(self._max_capacity)])
464+
torch.as_tensor([self._sum_tree[i] for i in range(self._max_capacity)])
465465
)
466466
mm_mt.copy_(
467-
torch.tensor([self._min_tree[i] for i in range(self._max_capacity)])
467+
torch.as_tensor([self._min_tree[i] for i in range(self._max_capacity)])
468468
)
469469
with open(path / "sampler_metadata.json", "w") as file:
470470
json.dump(

torchrl/data/replay_buffers/storages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,7 @@ def __getitem__(self, index):
10051005
if isinstance(index, slice) and index == slice(None):
10061006
return self
10071007
if isinstance(index, (list, range, np.ndarray)):
1008-
index = torch.tensor(index)
1008+
index = torch.as_tensor(index)
10091009
if isinstance(index, torch.Tensor):
10101010
if index.ndim > 1:
10111011
raise RuntimeError(

torchrl/data/replay_buffers/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ def _to_torch(
2828
data: Tensor, device, pin_memory: bool = False, non_blocking: bool = False
2929
) -> torch.Tensor:
3030
if isinstance(data, np.generic):
31-
return torch.tensor(data, device=device)
31+
return torch.as_tensor(data, device=device)
3232
elif isinstance(data, np.ndarray):
3333
data = torch.from_numpy(data)
3434
elif not isinstance(data, Tensor):
35-
data = torch.tensor(data, device=device)
35+
data = torch.as_tensor(data, device=device)
3636

3737
if pin_memory:
3838
data = data.pin_memory()

torchrl/data/replay_buffers/writers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def __getstate__(self):
357357
def dumps(self, path):
358358
path = Path(path).absolute()
359359
path.mkdir(exist_ok=True)
360-
t = torch.tensor(self._current_top_values)
360+
t = torch.as_tensor(self._current_top_values)
361361
try:
362362
MemoryMappedTensor.from_filename(
363363
filename=path / "current_top_values.memmap",
@@ -453,7 +453,7 @@ def __getitem__(self, index):
453453
if isinstance(index, slice) and index == slice(None):
454454
return self
455455
if isinstance(index, (list, range, np.ndarray)):
456-
index = torch.tensor(index)
456+
index = torch.as_tensor(index)
457457
if isinstance(index, torch.Tensor):
458458
if index.ndim > 1:
459459
raise RuntimeError(

torchrl/data/rlhf/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def update(self, kl_values: Sequence[float]):
100100
)
101101
n_steps = len(kl_values)
102102
# renormalize kls
103-
kl_value = -torch.tensor(kl_values).mean() / self.coef
103+
kl_value = -torch.as_tensor(kl_values).mean() / self.coef
104104
proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ
105105
mult = 1 + proportional_error * n_steps / self.horizon
106106
self.coef *= mult # βₜ₊₁
@@ -314,10 +314,10 @@ def _get_done_status(self, generated, batch):
314314
# of generated tokens
315315
done_idx = torch.minimum(
316316
(generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex,
317-
torch.tensor(self.max_new_tokens) - 1,
317+
torch.as_tensor(self.max_new_tokens) - 1,
318318
)
319319
truncated_idx = (
320-
torch.tensor(self.max_new_tokens, device=generated.device).expand_as(
320+
torch.as_tensor(self.max_new_tokens, device=generated.device).expand_as(
321321
done_idx
322322
)
323323
- 1

torchrl/data/tensor_specs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,9 +1374,9 @@ def encode(
13741374
) -> torch.Tensor:
13751375
if not isinstance(val, torch.Tensor):
13761376
if ignore_device:
1377-
val = torch.tensor(val)
1377+
val = torch.as_tensor(val)
13781378
else:
1379-
val = torch.tensor(val, device=self.device)
1379+
val = torch.as_tensor(val, device=self.device)
13801380

13811381
if space is None:
13821382
space = self.space
@@ -1555,9 +1555,9 @@ def __init__(
15551555
dtype = torch.get_default_dtype()
15561556

15571557
if not isinstance(low, torch.Tensor):
1558-
low = torch.tensor(low, dtype=dtype, device=device)
1558+
low = torch.as_tensor(low, dtype=dtype, device=device)
15591559
if not isinstance(high, torch.Tensor):
1560-
high = torch.tensor(high, dtype=dtype, device=device)
1560+
high = torch.as_tensor(high, dtype=dtype, device=device)
15611561
if high.device != device:
15621562
high = high.to(device)
15631563
if low.device != device:
@@ -1857,8 +1857,8 @@ def __init__(
18571857
dtype, device = _default_dtype_and_device(dtype, device)
18581858
box = (
18591859
ContinuousBox(
1860-
torch.tensor(-np.inf, device=device).expand(shape),
1861-
torch.tensor(np.inf, device=device).expand(shape),
1860+
torch.as_tensor(-np.inf, device=device).expand(shape),
1861+
torch.as_tensor(np.inf, device=device).expand(shape),
18621862
)
18631863
if shape == _DEFAULT_SHAPE
18641864
else None

0 commit comments

Comments
 (0)