Skip to content

Commit 2d86afc

Browse files
committed
Update
[ghstack-poisoned]
2 parents f16655f + fcc23d7 commit 2d86afc

File tree

12 files changed

+114
-62
lines changed

12 files changed

+114
-62
lines changed

Diff for: test/test_env.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -1692,6 +1692,34 @@ def test_parallel_env_device(
16921692
env_serial.close(raise_if_closed=False)
16931693
env0.close(raise_if_closed=False)
16941694

1695+
@pytest.mark.skipif(not _has_gym, reason="no gym")
@pytest.mark.parametrize("env_device", [None, "cpu"])
def test_parallel_env_device_vs_no_device(self, maybe_fork_ParallelEnv, env_device):
    """Compare seeded rollouts of a ParallelEnv built with and without a device.

    The same single-worker ParallelEnv is created twice, once with
    ``device=None`` and once with ``device="cpu"``; both are seeded
    identically and their 10-step rollouts must match.
    """

    def make_env() -> GymEnv:
        env = GymEnv(PENDULUM_VERSIONED(), device=env_device)
        return env.append_transform(DoubleToFloat())

    # Every env created below is tracked so that it is closed even if a
    # rollout or the final comparison fails (otherwise workers would leak).
    opened_envs = []

    def seeded_rollout(device):
        parallel_env = maybe_fork_ParallelEnv(
            num_workers=1, create_env_fn=make_env, device=device
        )
        opened_envs.append(parallel_env)
        parallel_env.reset()
        parallel_env.set_seed(0)
        torch.manual_seed(0)
        return parallel_env.rollout(max_steps=10)

    try:
        # Rollout with a ParallelEnv that has no device set
        parallel_rollout = seeded_rollout(None)
        # Rollout with a ParallelEnv that has an explicit device
        parallel_rollout_cpu = seeded_rollout("cpu")
        assert_allclose_td(parallel_rollout, parallel_rollout_cpu)
    finally:
        for parallel_env in opened_envs:
            parallel_env.close(raise_if_closed=False)
1722+
16951723
@pytest.mark.skipif(not _has_gym, reason="no gym")
16961724
@pytest.mark.flaky(reruns=3, reruns_delay=1)
16971725
@pytest.mark.parametrize(
@@ -4907,7 +4935,6 @@ def policy(td):
49074935
if assign_done:
49084936
assert "terminated" in r
49094937
assert "done" in r
4910-
print(r)
49114938

49124939

49134940
if __name__ == "__main__":

Diff for: test/test_storage_map.py

+11
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,17 @@ def test_edges(self):
350350
edges_check = {(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)}
351351
assert edges == edges_check
352352

353+
def test_make_node(self):
    """Check that every way of building a Tree keeps the node data."""
    td = TensorDict({"obs": torch.tensor([0])})
    # Built lazily, in order: direct constructor, then make_node with the
    # data passed by keyword, then make_node with the data passed positionally.
    builders = (
        lambda: Tree(node_data=td),
        lambda: Tree.make_node(data=td),
        lambda: Tree.make_node(td),
    )
    for build in builders:
        tree = build()
        assert tree.node_data is not None
363+
353364

354365
class TestMCTSForest:
355366
def dummy_rollouts(self) -> Tuple[TensorDict, ...]:

Diff for: torchrl/_utils.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import warnings
1919
from contextlib import nullcontext
2020
from copy import copy
21-
from distutils.util import strtobool
2221
from functools import wraps
2322
from importlib import import_module
2423
from typing import Any, Callable, cast, TypeVar
@@ -35,6 +34,21 @@
3534
except ImportError:
3635
from torch._dynamo import is_compiling
3736

37+
38+
def strtobool(val: Any) -> bool:
    """Convert a string representation of truth to a boolean.

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are
    'n', 'no', 'f', 'false', 'off', and '0'. Matching is case-insensitive and
    ignores leading/trailing whitespace.

    Args:
        val: the value to interpret. A ``bool`` is returned unchanged; any
            other non-string value is converted with ``str`` before matching.

    Returns:
        ``True`` or ``False`` according to the tables above.

    Raises:
        ValueError: if ``val`` is not one of the recognised values.
    """
    # The signature accepts Any, so do not assume a ``.lower()`` method exists.
    if isinstance(val, bool):
        return val
    val = str(val).strip().lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return True
    if val in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"Invalid truth value {val!r}")
50+
51+
3852
LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
3953
logger = logging.getLogger("torchrl")
4054
logger.setLevel(getattr(logging, LOGGING_LEVEL))

Diff for: torchrl/data/llm/__init__.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111
)
1212
from .prompt import PromptData, PromptTensorDictTokenizer
1313
from .reward import PairwiseDataset, RewardData
14-
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel, LLMData, LLMOutput, LLMInput
14+
from .utils import (
15+
AdaptiveKLController,
16+
ConstantKLController,
17+
LLMData,
18+
LLMInput,
19+
LLMOutput,
20+
RolloutFromModel,
21+
)
1522

1623
__all__ = [
1724
"AdaptiveKLController",

Diff for: torchrl/data/llm/utils.py

+7
Original file line numberDiff line numberDiff line change
@@ -543,8 +543,10 @@ def step_scheduler(self):
543543
while len(self._kl_queue):
544544
self._kl_queue.remove(self._kl_queue[0])
545545

546+
546547
LLMInpOut = TypeVar("LLMInpOut")
547548

549+
548550
class LLMInput(TensorClass["nocast"]):
549551
"""Represents the input to a Large Language Model (LLM).
550552
@@ -557,11 +559,13 @@ class LLMInput(TensorClass["nocast"]):
557559
.. seealso:: :class:`~torchrl.data.LLMOutput` and :class:`~torchrl.data.LLMData`.
558560
559561
"""
562+
560563
tokens: torch.Tensor
561564
attention_mask: torch.Tensor | None = None
562565
token_list: list[int] | list[list[int]] | None = None
563566
text: str | list[str] | None = None
564567

568+
565569
class LLMOutput(TensorClass["nocast"]):
566570
"""Represents the output from a Large Language Model (LLM).
567571
@@ -581,6 +585,7 @@ class LLMOutput(TensorClass["nocast"]):
581585
.. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMData`.
582586
583587
"""
588+
584589
tokens: torch.Tensor
585590
tokens_response: torch.Tensor | None = None
586591
token_list: list[int] | list[list[int]] | None = None
@@ -594,6 +599,7 @@ def from_vllm_output(cls: type[LLMInpOut], vllm_output) -> LLMInpOut:
594599
# placeholder
595600
raise NotImplementedError
596601

602+
597603
class LLMData(TensorClass["nocast"]):
598604
"""Represents the input or output of a Large Language Model (LLM).
599605
@@ -619,6 +625,7 @@ class LLMData(TensorClass["nocast"]):
619625
.. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMOutput`.
620626
621627
"""
628+
622629
tokens: torch.Tensor
623630
tokens_response: torch.Tensor | None = None
624631
attention_mask: torch.Tensor | None = None

Diff for: torchrl/data/map/tree.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def make_node(
122122
return cls(
123123
count=torch.zeros(()),
124124
wins=torch.zeros(()),
125-
node=data.exclude("action", "next"),
125+
node_data=data.exclude("action", "next"),
126126
rollout=rollout,
127127
subtree=subtree,
128128
device=device,

Diff for: torchrl/data/postprocs/postprocs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from torch import nn
1313

1414

15-
1615
def _get_reward(
1716
gamma: float,
1817
reward: torch.Tensor,
@@ -367,6 +366,7 @@ def __init__(
367366
discount: float = 1.0,
368367
):
369368
from torchrl.objectives.value.functional import reward2go
369+
370370
super().__init__()
371371
self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
372372
if reward_key_out is None:

Diff for: torchrl/envs/batched_envs.py

+19-35
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,14 @@ def __init__(
379379

380380
is_spec_locked = EnvBase.is_spec_locked
381381

382+
def select_and_clone(self, name, tensor, selected_keys=None):
    """Return a fresh copy of ``tensor`` if ``name`` is one of the selected keys.

    Args:
        name: key of the entry being visited.
        tensor: value stored under ``name``.
        selected_keys: collection of keys to keep. Defaults to
            ``self._selected_step_keys`` when ``None``.

    Returns:
        ``None`` when ``name`` is not selected. Otherwise the tensor moved to
        ``self.device`` if that device is set and differs from the tensor's
        device, or a clone of the tensor.
    """
    keys = self._selected_step_keys if selected_keys is None else selected_keys
    if name not in keys:
        # Nothing to return for unselected entries (presumably dropped by the
        # caller's ``filter_empty=True`` -- confirm against named_apply).
        return None
    target = self.device
    if target is None or tensor.device == target:
        # Already on the right device: an explicit copy is still required.
        return tensor.clone()
    return tensor.to(target, non_blocking=self.non_blocking)
389+
382390
@property
383391
def non_blocking(self):
384392
nb = self._non_blocking
@@ -1072,12 +1080,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
10721080
selected_output_keys = self._selected_reset_keys_filt
10731081

10741082
# select + clone creates 2 tds, but we can create one only
1075-
def select_and_clone(name, tensor):
1076-
if name in selected_output_keys:
1077-
return tensor.clone()
1078-
10791083
out = self.shared_tensordict_parent.named_apply(
1080-
select_and_clone,
1084+
lambda *args: self.select_and_clone(
1085+
*args, selected_keys=selected_output_keys
1086+
),
10811087
nested_keys=True,
10821088
filter_empty=True,
10831089
)
@@ -1150,14 +1156,14 @@ def _step(
11501156
# will be modified in-place at further steps
11511157
device = self.device
11521158

1153-
def select_and_clone(name, tensor):
1154-
if name in self._selected_step_keys:
1155-
return tensor.clone()
1159+
selected_keys = self._selected_step_keys
11561160

11571161
if partial_steps is not None:
11581162
next_td = TensorDict.lazy_stack([next_td[i] for i in workers_range])
11591163
out = next_td.named_apply(
1160-
select_and_clone, nested_keys=True, filter_empty=True
1164+
lambda *args: self.select_and_clone(*args, selected_keys),
1165+
nested_keys=True,
1166+
filter_empty=True,
11611167
)
11621168
if out_tds is not None:
11631169
out.update(
@@ -2010,20 +2016,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
20102016
next_td = shared_tensordict_parent.get("next")
20112017
device = self.device
20122018

2013-
if next_td.device != device and device is not None:
2014-
2015-
def select_and_clone(name, tensor):
2016-
if name in self._selected_step_keys:
2017-
return tensor.to(device, non_blocking=self.non_blocking)
2018-
2019-
else:
2020-
2021-
def select_and_clone(name, tensor):
2022-
if name in self._selected_step_keys:
2023-
return tensor.clone()
2024-
20252019
out = next_td.named_apply(
2026-
select_and_clone,
2020+
self.select_and_clone,
20272021
nested_keys=True,
20282022
filter_empty=True,
20292023
device=device,
@@ -2203,20 +2197,10 @@ def tentative_update(val, other):
22032197
selected_output_keys = self._selected_reset_keys_filt
22042198
device = self.device
22052199

2206-
if self.shared_tensordict_parent.device != device and device is not None:
2207-
2208-
def select_and_clone(name, tensor):
2209-
if name in selected_output_keys:
2210-
return tensor.to(device, non_blocking=self.non_blocking)
2211-
2212-
else:
2213-
2214-
def select_and_clone(name, tensor):
2215-
if name in selected_output_keys:
2216-
return tensor.clone()
2217-
22182200
out = self.shared_tensordict_parent.named_apply(
2219-
select_and_clone,
2201+
lambda *args: self.select_and_clone(
2202+
*args, selected_keys=selected_output_keys
2203+
),
22202204
nested_keys=True,
22212205
filter_empty=True,
22222206
device=device,

Diff for: torchrl/envs/custom/llm.py

+12-17
Original file line numberDiff line numberDiff line change
@@ -80,41 +80,37 @@ def __init__(
8080
self._batch_locked = False
8181
else:
8282
self._batch_locked = True
83-
super().__init__(device=device, batch_size=() if batch_size is None else (batch_size,))
83+
super().__init__(
84+
device=device, batch_size=() if batch_size is None else (batch_size,)
85+
)
8486
self.str2str = str2str
8587
self.vocab_size = vocab_size
8688
self.observation_key = unravel_key(token_key)
87-
self.attention_key = unravel_key(attention_key)
89+
if attention_key is not None:
90+
attention_key = unravel_key(attention_key)
91+
self.attention_key = attention_key
8892
self.no_stack = no_stack
8993
self.assign_reward = assign_reward
9094
self.assign_done = assign_done
9195

9296
# self.action_key = unravel_key(action_key)
9397
if str2str:
9498
self.full_observation_spec_unbatched = Composite(
95-
{
96-
token_key: NonTensor(
97-
example_data="a string", batched=True, shape=()
98-
)
99-
}
99+
{token_key: NonTensor(example_data="a string", batched=True, shape=())}
100100
)
101101
self.full_action_spec_unbatched = Composite(
102102
{action_key: NonTensor(example_data="a string", batched=True, shape=())}
103103
)
104104
else:
105105
if vocab_size is None:
106106
observation_spec = {
107-
token_key: Unbounded(
108-
shape=(-1,), dtype=torch.int64, device=device
109-
)
110-
}
107+
token_key: Unbounded(shape=(-1,), dtype=torch.int64, device=device)
108+
}
111109
if attention_key is not None:
112110
observation_spec[attention_key] = Unbounded(
113-
shape=(-1,), dtype=torch.int64, device=device
114-
)
115-
self.full_observation_spec_unbatched = Composite(
116-
observation_spec
117-
)
111+
shape=(-1,), dtype=torch.int64, device=device
112+
)
113+
self.full_observation_spec_unbatched = Composite(observation_spec)
118114
self.full_action_spec_unbatched = Composite(
119115
{
120116
action_key: Unbounded(
@@ -392,7 +388,6 @@ def _make_next_obs(
392388

393389
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
394390
# We should have an observation by this time, if not raise an exception
395-
print('tensordict', tensordict)
396391
if tensordict is None or self.observation_key not in tensordict.keys(
397392
isinstance(self.observation_key, tuple)
398393
):

Diff for: torchrl/envs/libs/unity_mlagents.py

-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ def _collect_agents(self, env):
132132
for steps_idx in [0, 1]:
133133
for behavior in env.behavior_specs.keys():
134134
steps = env.get_steps(behavior)[steps_idx]
135-
is_terminal = steps_idx == 1
136135
agent_ids = steps.agent_id
137136
group_ids = steps.group_id
138137

Diff for: torchrl/envs/utils.py

-2
Original file line numberDiff line numberDiff line change
@@ -1407,7 +1407,6 @@ def _update_during_reset(
14071407
if not reset_keys:
14081408
return tensordict.update(tensordict_reset)
14091409
roots = set()
1410-
print("reset_keys", reset_keys)
14111410
for reset_key in reset_keys:
14121411
# get the node of the reset key
14131412
if isinstance(reset_key, tuple):
@@ -1423,7 +1422,6 @@ def _update_during_reset(
14231422
reset_key_tuple = (reset_key,)
14241423
# get the reset signal
14251424
reset = tensordict.pop(reset_key, None)
1426-
print("reset popped", reset)
14271425

14281426
# check if this reset should be ignored -- this happens whenever the
14291427
# root node has already been updated

Diff for: torchrl/objectives/value/advantages.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1281,8 +1281,18 @@ def __init__(
12811281
skip_existing=skip_existing,
12821282
device=device,
12831283
)
1284-
self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
1285-
self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device))
1284+
self.register_buffer(
1285+
"gamma",
1286+
gamma.to(self._device)
1287+
if isinstance(gamma, Tensor)
1288+
else torch.tensor(gamma, device=self._device),
1289+
)
1290+
self.register_buffer(
1291+
"lmbda",
1292+
lmbda.to(self._device)
1293+
if isinstance(lmbda, Tensor)
1294+
else torch.tensor(lmbda, device=self._device),
1295+
)
12861296
self.average_gae = average_gae
12871297
self.vectorized = vectorized
12881298
self.time_dim = time_dim

0 commit comments

Comments
 (0)