Skip to content

Commit 69453a6

Browse files
authored
[BugFix] Fix flaky gym penv test (#1853)
1 parent 2754200 commit 69453a6

26 files changed

+78
-73
lines changed

test/_utils_internal.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,9 @@ def rollout_consistency_assertion(
330330
):
331331
"""Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise."""
332332

333-
done = rollout[:, :-1]["next", done_key].squeeze(-1)
333+
done = rollout[..., :-1]["next", done_key].squeeze(-1)
334334
# data resulting from step, when it's not done
335-
r_not_done = rollout[:, :-1]["next"][~done]
335+
r_not_done = rollout[..., :-1]["next"][~done]
336336
# data resulting from step, when it's not done, after step_mdp
337337
r_not_done_tp1 = rollout[:, 1:][~done]
338338
torch.testing.assert_close(
@@ -343,17 +343,15 @@ def rollout_consistency_assertion(
343343

344344
if done_strict and not done.any():
345345
raise RuntimeError("No done detected, test could not complete.")
346-
347-
# data resulting from step, when it's done
348-
r_done = rollout[:, :-1]["next"][done]
349-
# data resulting from step, when it's done, after step_mdp and reset
350-
r_done_tp1 = rollout[:, 1:][done]
351-
assert (
352-
(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1
353-
).all(), (
354-
f"Entries in next tensordict do not match entries in root "
355-
f"tensordict after reset : {(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) < 1e-1}"
356-
)
346+
if done.any():
347+
# data resulting from step, when it's done
348+
r_done = rollout[..., :-1]["next"][done]
349+
# data resulting from step, when it's done, after step_mdp and reset
350+
r_done_tp1 = rollout[..., 1:][done]
351+
# check that at least one obs after reset does not match the version before reset
352+
assert not torch.isclose(
353+
r_done[observation_key], r_done_tp1[observation_key]
354+
).all()
357355

358356

359357
def rand_reset(env):

torchrl/data/datasets/minari_data.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,8 @@ def _proc_spec(spec):
412412
)
413413
return BoundedTensorSpec(
414414
shape=spec["shape"],
415-
low=torch.tensor(spec["low"]),
416-
high=torch.tensor(spec["high"]),
415+
low=torch.as_tensor(spec["low"]),
416+
high=torch.as_tensor(spec["high"]),
417417
dtype=_DTYPE_DIR[spec["dtype"]],
418418
)
419419
elif spec["type"] == "Discrete":

torchrl/data/datasets/openx.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def _slice_data(data: TensorDict, slice_len, pad_value):
684684
truncated,
685685
dim=data.ndim - 1,
686686
value=True,
687-
index=torch.tensor(-1, device=truncated.device),
687+
index=torch.as_tensor(-1, device=truncated.device),
688688
)
689689
done = data.get(("next", "done"))
690690
data.set(("next", "truncated"), truncated)

torchrl/data/replay_buffers/replay_buffers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,7 @@ def add(self, data: TensorDictBase) -> int:
867867
device=data.device,
868868
)
869869
if data.batch_size:
870-
data_add["_rb_batch_size"] = torch.tensor(data.batch_size)
870+
data_add["_rb_batch_size"] = torch.as_tensor(data.batch_size)
871871

872872
else:
873873
data_add = data
@@ -1441,7 +1441,7 @@ def __getitem__(
14411441
if isinstance(index, slice) and index == slice(None):
14421442
return self
14431443
if isinstance(index, (list, range, np.ndarray)):
1444-
index = torch.tensor(index)
1444+
index = torch.as_tensor(index)
14451445
if isinstance(index, torch.Tensor):
14461446
if index.ndim > 1:
14471447
raise RuntimeError(

torchrl/data/replay_buffers/samplers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -461,10 +461,10 @@ def dumps(self, path):
461461
filename=path / "mintree.memmap",
462462
)
463463
mm_st.copy_(
464-
torch.tensor([self._sum_tree[i] for i in range(self._max_capacity)])
464+
torch.as_tensor([self._sum_tree[i] for i in range(self._max_capacity)])
465465
)
466466
mm_mt.copy_(
467-
torch.tensor([self._min_tree[i] for i in range(self._max_capacity)])
467+
torch.as_tensor([self._min_tree[i] for i in range(self._max_capacity)])
468468
)
469469
with open(path / "sampler_metadata.json", "w") as file:
470470
json.dump(

torchrl/data/replay_buffers/storages.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,7 @@ def __getitem__(self, index):
10051005
if isinstance(index, slice) and index == slice(None):
10061006
return self
10071007
if isinstance(index, (list, range, np.ndarray)):
1008-
index = torch.tensor(index)
1008+
index = torch.as_tensor(index)
10091009
if isinstance(index, torch.Tensor):
10101010
if index.ndim > 1:
10111011
raise RuntimeError(

torchrl/data/replay_buffers/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ def _to_torch(
2828
data: Tensor, device, pin_memory: bool = False, non_blocking: bool = False
2929
) -> torch.Tensor:
3030
if isinstance(data, np.generic):
31-
return torch.tensor(data, device=device)
31+
return torch.as_tensor(data, device=device)
3232
elif isinstance(data, np.ndarray):
3333
data = torch.from_numpy(data)
3434
elif not isinstance(data, Tensor):
35-
data = torch.tensor(data, device=device)
35+
data = torch.as_tensor(data, device=device)
3636

3737
if pin_memory:
3838
data = data.pin_memory()

torchrl/data/replay_buffers/writers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def __getstate__(self):
357357
def dumps(self, path):
358358
path = Path(path).absolute()
359359
path.mkdir(exist_ok=True)
360-
t = torch.tensor(self._current_top_values)
360+
t = torch.as_tensor(self._current_top_values)
361361
try:
362362
MemoryMappedTensor.from_filename(
363363
filename=path / "current_top_values.memmap",
@@ -453,7 +453,7 @@ def __getitem__(self, index):
453453
if isinstance(index, slice) and index == slice(None):
454454
return self
455455
if isinstance(index, (list, range, np.ndarray)):
456-
index = torch.tensor(index)
456+
index = torch.as_tensor(index)
457457
if isinstance(index, torch.Tensor):
458458
if index.ndim > 1:
459459
raise RuntimeError(

torchrl/data/rlhf/utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def update(self, kl_values: Sequence[float]):
100100
)
101101
n_steps = len(kl_values)
102102
# renormalize kls
103-
kl_value = -torch.tensor(kl_values).mean() / self.coef
103+
kl_value = -torch.as_tensor(kl_values).mean() / self.coef
104104
proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ
105105
mult = 1 + proportional_error * n_steps / self.horizon
106106
self.coef *= mult # βₜ₊₁
@@ -314,10 +314,10 @@ def _get_done_status(self, generated, batch):
314314
# of generated tokens
315315
done_idx = torch.minimum(
316316
(generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex,
317-
torch.tensor(self.max_new_tokens) - 1,
317+
torch.as_tensor(self.max_new_tokens) - 1,
318318
)
319319
truncated_idx = (
320-
torch.tensor(self.max_new_tokens, device=generated.device).expand_as(
320+
torch.as_tensor(self.max_new_tokens, device=generated.device).expand_as(
321321
done_idx
322322
)
323323
- 1

torchrl/data/tensor_specs.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1374,9 +1374,9 @@ def encode(
13741374
) -> torch.Tensor:
13751375
if not isinstance(val, torch.Tensor):
13761376
if ignore_device:
1377-
val = torch.tensor(val)
1377+
val = torch.as_tensor(val)
13781378
else:
1379-
val = torch.tensor(val, device=self.device)
1379+
val = torch.as_tensor(val, device=self.device)
13801380

13811381
if space is None:
13821382
space = self.space
@@ -1555,9 +1555,9 @@ def __init__(
15551555
dtype = torch.get_default_dtype()
15561556

15571557
if not isinstance(low, torch.Tensor):
1558-
low = torch.tensor(low, dtype=dtype, device=device)
1558+
low = torch.as_tensor(low, dtype=dtype, device=device)
15591559
if not isinstance(high, torch.Tensor):
1560-
high = torch.tensor(high, dtype=dtype, device=device)
1560+
high = torch.as_tensor(high, dtype=dtype, device=device)
15611561
if high.device != device:
15621562
high = high.to(device)
15631563
if low.device != device:
@@ -1857,8 +1857,8 @@ def __init__(
18571857
dtype, device = _default_dtype_and_device(dtype, device)
18581858
box = (
18591859
ContinuousBox(
1860-
torch.tensor(-np.inf, device=device).expand(shape),
1861-
torch.tensor(np.inf, device=device).expand(shape),
1860+
torch.as_tensor(-np.inf, device=device).expand(shape),
1861+
torch.as_tensor(np.inf, device=device).expand(shape),
18621862
)
18631863
if shape == _DEFAULT_SHAPE
18641864
else None

torchrl/envs/libs/dm_control.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,9 @@ def _get_envs(to_dict: bool = True) -> Dict[str, Any]:
102102

103103
def _robust_to_tensor(array: Union[float, np.ndarray]) -> torch.Tensor:
104104
if isinstance(array, np.ndarray):
105-
return torch.tensor(array.copy())
105+
return torch.as_tensor(array.copy())
106106
else:
107-
return torch.tensor(array)
107+
return torch.as_tensor(array)
108108

109109

110110
class DMControlWrapper(GymLikeEnv):

torchrl/envs/libs/envpool.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def _transform_step_output(
264264
f"The output of step was had {len(out)} elements, but only 4 or 5 are supported."
265265
)
266266
obs = self._treevalue_or_numpy_to_tensor_or_dict(obs)
267-
reward_and_done = {self.reward_key: torch.tensor(reward)}
267+
reward_and_done = {self.reward_key: torch.as_tensor(reward)}
268268
reward_and_done["done"] = done
269269
reward_and_done["terminated"] = terminated
270270
reward_and_done["truncated"] = truncated
@@ -290,7 +290,7 @@ def _treevalue_or_numpy_to_tensor_or_dict(
290290
if isinstance(x, treevalue.TreeValue):
291291
ret = self._treevalue_to_dict(x)
292292
elif not isinstance(x, dict):
293-
ret = {"observation": torch.tensor(x)}
293+
ret = {"observation": torch.as_tensor(x)}
294294
else:
295295
ret = x
296296
return ret
@@ -304,7 +304,7 @@ def _treevalue_to_dict(
304304
"""
305305
import treevalue
306306

307-
return {k[0]: torch.tensor(v) for k, v in treevalue.flatten(tv)}
307+
return {k[0]: torch.as_tensor(v) for k, v in treevalue.flatten(tv)}
308308

309309
def _set_seed(self, seed: Optional[int]):
310310
if seed is not None:

torchrl/envs/libs/gym.py

+1
Original file line numberDiff line numberDiff line change
@@ -1506,6 +1506,7 @@ def _read_obs(self, obs, key, tensor, index):
15061506
def __call__(self, info_dict, tensordict):
15071507
terminal_obs = info_dict.get(self.backend_key[self.backend], None)
15081508
for key, item in self.info_spec.items(True, True):
1509+
key = (key,) if isinstance(key, str) else key
15091510
final_obs_buffer = item.zero()
15101511
if terminal_obs is not None:
15111512
for i, obs in enumerate(terminal_obs):

torchrl/envs/libs/pettingzoo.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def _init_env(self):
462462
"info": CompositeSpec(
463463
{
464464
key: UnboundedContinuousTensorSpec(
465-
shape=torch.tensor(value).shape,
465+
shape=torch.as_tensor(value).shape,
466466
device=self.device,
467467
)
468468
for key, value in info_dict[agent].items()
@@ -501,7 +501,7 @@ def _init_env(self):
501501
device=self.device,
502502
)
503503
except AttributeError:
504-
state_example = torch.tensor(self.state(), device=self.device)
504+
state_example = torch.as_tensor(self.state(), device=self.device)
505505
state_spec = UnboundedContinuousTensorSpec(
506506
shape=state_example.shape,
507507
dtype=state_example.dtype,
@@ -560,7 +560,7 @@ def _reset(
560560
if group_info is not None:
561561
agent_info_dict = info_dict[agent]
562562
for agent_info, value in agent_info_dict.items():
563-
group_info.get(agent_info)[index] = torch.tensor(
563+
group_info.get(agent_info)[index] = torch.as_tensor(
564564
value, device=self.device
565565
)
566566

torchrl/envs/transforms/gym_transforms.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _get_lives(self):
135135
if callable(lives):
136136
lives = lives()
137137
elif isinstance(lives, list) and all(callable(_lives) for _lives in lives):
138-
lives = torch.tensor([_lives() for _lives in lives])
138+
lives = torch.as_tensor([_lives() for _lives in lives])
139139
return lives
140140

141141
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -170,7 +170,7 @@ def _reset(self, tensordict, tensordict_reset):
170170
end_of_life = False
171171
tensordict_reset.set(
172172
self.eol_key,
173-
torch.tensor(end_of_life).expand(
173+
torch.as_tensor(end_of_life).expand(
174174
parent.full_done_spec[self.done_key].shape
175175
),
176176
)

torchrl/envs/transforms/r3m.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@ def _init(self):
292292
std = [0.229, 0.224, 0.225]
293293
normalize = ObservationNorm(
294294
in_keys=in_keys,
295-
loc=torch.tensor(mean).view(3, 1, 1),
296-
scale=torch.tensor(std).view(3, 1, 1),
295+
loc=torch.as_tensor(mean).view(3, 1, 1),
296+
scale=torch.as_tensor(std).view(3, 1, 1),
297297
standard_normal=True,
298298
)
299299
transforms.append(normalize)

torchrl/envs/transforms/rlhf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def find_sample_log_prob(module):
146146
self.functional_actor.apply(find_sample_log_prob)
147147

148148
if not isinstance(coef, torch.Tensor):
149-
coef = torch.tensor(coef)
149+
coef = torch.as_tensor(coef)
150150
self.register_buffer("coef", coef)
151151

152152
def _reset(

torchrl/envs/transforms/transforms.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1332,7 +1332,7 @@ def check_val(val):
13321332
if val is None:
13331333
return None, None, torch.finfo(torch.get_default_dtype()).max
13341334
if not isinstance(val, torch.Tensor):
1335-
val = torch.tensor(val)
1335+
val = torch.as_tensor(val)
13361336
if not val.dtype.is_floating_point:
13371337
val = val.float()
13381338
eps = torch.finfo(val.dtype).resolution
@@ -1626,10 +1626,10 @@ def __init__(
16261626
out_keys = copy(in_keys)
16271627
super().__init__(in_keys=in_keys, out_keys=out_keys)
16281628
clamp_min_tensor = (
1629-
clamp_min if isinstance(clamp_min, Tensor) else torch.tensor(clamp_min)
1629+
clamp_min if isinstance(clamp_min, Tensor) else torch.as_tensor(clamp_min)
16301630
)
16311631
clamp_max_tensor = (
1632-
clamp_max if isinstance(clamp_max, Tensor) else torch.tensor(clamp_max)
1632+
clamp_max if isinstance(clamp_max, Tensor) else torch.as_tensor(clamp_max)
16331633
)
16341634
self.register_buffer("clamp_min", clamp_min_tensor)
16351635
self.register_buffer("clamp_max", clamp_max_tensor)
@@ -2396,7 +2396,7 @@ def __init__(
23962396
out_keys_inv=out_keys_inv,
23972397
)
23982398
if not isinstance(standard_normal, torch.Tensor):
2399-
standard_normal = torch.tensor(standard_normal)
2399+
standard_normal = torch.as_tensor(standard_normal)
24002400
self.register_buffer("standard_normal", standard_normal)
24012401
self.eps = 1e-6
24022402

torchrl/envs/transforms/vc1.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ def _map_tv_to_torchrl(
132132
elif isinstance(model_transforms, transforms.Normalize):
133133
return ObservationNorm(
134134
in_keys=in_keys,
135-
loc=torch.tensor(model_transforms.mean).reshape(3, 1, 1),
136-
scale=torch.tensor(model_transforms.std).reshape(3, 1, 1),
135+
loc=torch.as_tensor(model_transforms.mean).reshape(3, 1, 1),
136+
scale=torch.as_tensor(model_transforms.std).reshape(3, 1, 1),
137137
standard_normal=True,
138138
)
139139
elif isinstance(model_transforms, transforms.ToTensor):

torchrl/envs/transforms/vip.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ def _init(self):
266266
std = [0.229, 0.224, 0.225]
267267
normalize = ObservationNorm(
268268
in_keys=in_keys,
269-
loc=torch.tensor(mean).view(3, 1, 1),
270-
scale=torch.tensor(std).view(3, 1, 1),
269+
loc=torch.as_tensor(mean).view(3, 1, 1),
270+
scale=torch.as_tensor(std).view(3, 1, 1),
271271
standard_normal=True,
272272
)
273273
transforms.append(normalize)

torchrl/modules/distributions/continuous.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,11 @@ def __init__(
240240
if isinstance(max, torch.Tensor):
241241
max = max.to(self.device)
242242
else:
243-
max = torch.tensor(max, device=self.device)
243+
max = torch.as_tensor(max, device=self.device)
244244
if isinstance(min, torch.Tensor):
245245
min = min.to(self.device)
246246
else:
247-
min = torch.tensor(min, device=self.device)
247+
min = torch.as_tensor(min, device=self.device)
248248
self.min = min
249249
self.max = max
250250
self.update(loc, scale)

torchrl/modules/models/exploration.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,9 @@ def __init__(
345345
)
346346

347347
if sigma_init != 0.0:
348-
self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device))
348+
self.register_buffer(
349+
"sigma_init", torch.as_tensor(sigma_init, device=device)
350+
)
349351

350352
@property
351353
def sigma(self):

torchrl/modules/planners/mppi.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def __init__(
145145
self.num_candidates = num_candidates
146146
self.top_k = top_k
147147
self.reward_key = reward_key
148-
self.register_buffer("temperature", torch.tensor(temperature))
148+
self.register_buffer("temperature", torch.as_tensor(temperature))
149149

150150
def planning(self, tensordict: TensorDictBase) -> torch.Tensor:
151151
batch_size = tensordict.batch_size

0 commit comments

Comments
 (0)