
Commit d36a97e

Update

[ghstack-poisoned]

2 parents: 249733f + 1a6a529

File tree

8 files changed: +101 -40 lines


test/test_env.py (+28)

@@ -1692,6 +1692,34 @@ def test_parallel_env_device(
         env_serial.close(raise_if_closed=False)
         env0.close(raise_if_closed=False)
 
+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.parametrize("env_device", [None, "cpu"])
+    def test_parallel_env_device_vs_no_device(self, maybe_fork_ParallelEnv, env_device):
+        def make_env() -> GymEnv:
+            env = GymEnv(PENDULUM_VERSIONED(), device=env_device)
+            return env.append_transform(DoubleToFloat())
+
+        # Rollouts work with a regular env
+        parallel_env = maybe_fork_ParallelEnv(
+            num_workers=1, create_env_fn=make_env, device=None
+        )
+        parallel_env.reset()
+        parallel_env.set_seed(0)
+        torch.manual_seed(0)
+
+        parallel_rollout = parallel_env.rollout(max_steps=10)
+
+        # Rollout doesn't work with ParallelEnv
+        parallel_env = maybe_fork_ParallelEnv(
+            num_workers=1, create_env_fn=make_env, device="cpu"
+        )
+        parallel_env.reset()
+        parallel_env.set_seed(0)
+        torch.manual_seed(0)
+
+        parallel_rollout_cpu = parallel_env.rollout(max_steps=10)
+        assert_allclose_td(parallel_rollout, parallel_rollout_cpu)
+
     @pytest.mark.skipif(not _has_gym, reason="no gym")
     @pytest.mark.flaky(reruns=3, reruns_delay=1)
     @pytest.mark.parametrize(

test/test_storage_map.py (+11)

@@ -350,6 +350,17 @@ def test_edges(self):
         edges_check = {(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)}
         assert edges == edges_check
 
+    def test_make_node(self):
+        td = TensorDict({"obs": torch.tensor([0])})
+        tree = Tree(node_data=td)
+        assert tree.node_data is not None
+
+        tree = Tree.make_node(data=td)
+        assert tree.node_data is not None
+
+        tree = Tree.make_node(td)
+        assert tree.node_data is not None
+
 
 class TestMCTSForest:
     def dummy_rollouts(self) -> Tuple[TensorDict, ...]:

torchrl/_utils.py (+15 -1)

@@ -18,7 +18,6 @@
 import warnings
 from contextlib import nullcontext
 from copy import copy
-from distutils.util import strtobool
 from functools import wraps
 from importlib import import_module
 from typing import Any, Callable, cast, TypeVar

@@ -35,6 +34,21 @@
 except ImportError:
     from torch._dynamo import is_compiling
 
+
+def strtobool(val: Any) -> bool:
+    """Convert a string representation of truth to a boolean.
+
+    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
+    Raises ValueError if 'val' is anything else.
+    """
+    val = val.lower()
+    if val in ("y", "yes", "t", "true", "on", "1"):
+        return True
+    if val in ("n", "no", "f", "false", "off", "0"):
+        return False
+    raise ValueError(f"Invalid truth value {val!r}")
+
+
 LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
 logger = logging.getLogger("torchrl")
 logger.setLevel(getattr(logging, LOGGING_LEVEL))
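
The vendored strtobool replaces distutils.util.strtobool, which is gone on Python 3.12+ since distutils was removed from the standard library; unlike the distutils version it returns a bool rather than 0/1. A minimal usage sketch, assuming the private torchrl._utils import path shown in the diff and a made-up TORCHRL_DEBUG environment variable:

import os

from torchrl._utils import strtobool

# parse a feature flag from the environment (hypothetical variable name)
debug = strtobool(os.environ.get("TORCHRL_DEBUG", "0"))  # -> False by default
assert strtobool("Yes") is True and strtobool("off") is False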

torchrl/data/llm/__init__.py (+8 -1)

@@ -11,7 +11,14 @@
 )
 from .prompt import PromptData, PromptTensorDictTokenizer
 from .reward import PairwiseDataset, RewardData
-from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel, LLMData, LLMOutput, LLMInput
+from .utils import (
+    AdaptiveKLController,
+    ConstantKLController,
+    LLMData,
+    LLMInput,
+    LLMOutput,
+    RolloutFromModel,
+)
 
 __all__ = [
     "AdaptiveKLController",

torchrl/data/llm/utils.py (+7)

@@ -543,8 +543,10 @@ def step_scheduler(self):
         while len(self._kl_queue):
             self._kl_queue.remove(self._kl_queue[0])
 
+
 LLMInpOut = TypeVar("LLMInpOut")
 
+
 class LLMInput(TensorClass["nocast"]):
     """Represents the input to a Large Language Model (LLM).

@@ -557,11 +559,13 @@ class LLMInput(TensorClass["nocast"]):
     .. seealso:: :class:`~torchrl.data.LLMOutput` and :class:`~torchrl.data.LLMData`.
 
     """
+
     tokens: torch.Tensor
     attention_mask: torch.Tensor | None = None
     token_list: list[int] | list[list[int]] | None = None
     text: str | list[str] | None = None
 
+
 class LLMOutput(TensorClass["nocast"]):
     """Represents the output from a Large Language Model (LLM).

@@ -581,6 +585,7 @@ class LLMOutput(TensorClass["nocast"]):
     .. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMData`.
 
     """
+
     tokens: torch.Tensor
     tokens_response: torch.Tensor | None = None
     token_list: list[int] | list[list[int]] | None = None

@@ -594,6 +599,7 @@ def from_vllm_output(cls: type[LLMInpOut], vllm_output) -> LLMInpOut:
         # placeholder
         raise NotImplementedError
 
+
 class LLMData(TensorClass["nocast"]):
     """Represents the input or output of a Large Language Model (LLM).

@@ -619,6 +625,7 @@ class LLMData(TensorClass["nocast"]):
     .. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMOutput`.
 
     """
+
     tokens: torch.Tensor
     tokens_response: torch.Tensor | None = None
     attention_mask: torch.Tensor | None = None
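
For orientation, a hedged sketch of constructing one of these tensorclasses. The field names follow the diff above; the torchrl.data.llm import path and the batch_size keyword are assumptions based on the usual tensorclass API, and only tokens is required since the other fields default to None:

import torch

from torchrl.data.llm import LLMInput

# a single prompt of three tokens
inp = LLMInput(
    tokens=torch.tensor([[1, 2, 3]]),
    attention_mask=torch.ones(1, 3, dtype=torch.bool),
    batch_size=[1],
)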

torchrl/data/map/tree.py (+1 -1)

@@ -122,7 +122,7 @@ def make_node(
         return cls(
             count=torch.zeros(()),
             wins=torch.zeros(()),
-            node=data.exclude("action", "next"),
+            node_data=data.exclude("action", "next"),
             rollout=rollout,
             subtree=subtree,
             device=device,
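
The one-line fix matters because the Tree tensorclass exposes a node_data field, not node, so make_node previously passed data under the wrong keyword. A hedged usage sketch mirroring the new test (the import path for Tree is an assumption; the test imports it at module level):

import torch
from tensordict import TensorDict

from torchrl.data.map.tree import Tree  # import path assumed

data = TensorDict({"obs": torch.tensor([0])})
node = Tree.make_node(data)  # data is now forwarded as node_data
assert node.node_data is not None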

torchrl/envs/batched_envs.py (+19 -35)

@@ -379,6 +379,14 @@ def __init__(
 
     is_spec_locked = EnvBase.is_spec_locked
 
+    def select_and_clone(self, name, tensor, selected_keys=None):
+        if selected_keys is None:
+            selected_keys = self._selected_step_keys
+        if name in selected_keys:
+            if self.device is not None and tensor.device != self.device:
+                return tensor.to(self.device, non_blocking=self.non_blocking)
+            return tensor.clone()
+
     @property
     def non_blocking(self):
         nb = self._non_blocking

@@ -1072,12 +1080,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         selected_output_keys = self._selected_reset_keys_filt
 
         # select + clone creates 2 tds, but we can create one only
-        def select_and_clone(name, tensor):
-            if name in selected_output_keys:
-                return tensor.clone()
-
         out = self.shared_tensordict_parent.named_apply(
-            select_and_clone,
+            lambda *args: self.select_and_clone(
+                *args, selected_keys=selected_output_keys
+            ),
             nested_keys=True,
             filter_empty=True,
         )

@@ -1150,14 +1156,14 @@ def _step(
         # will be modified in-place at further steps
         device = self.device
 
-        def select_and_clone(name, tensor):
-            if name in self._selected_step_keys:
-                return tensor.clone()
+        selected_keys = self._selected_step_keys
 
         if partial_steps is not None:
             next_td = TensorDict.lazy_stack([next_td[i] for i in workers_range])
             out = next_td.named_apply(
-                select_and_clone, nested_keys=True, filter_empty=True
+                lambda *args: self.select_and_clone(*args, selected_keys),
+                nested_keys=True,
+                filter_empty=True,
             )
             if out_tds is not None:
                 out.update(

@@ -2010,20 +2016,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         next_td = shared_tensordict_parent.get("next")
         device = self.device
 
-        if next_td.device != device and device is not None:
-
-            def select_and_clone(name, tensor):
-                if name in self._selected_step_keys:
-                    return tensor.to(device, non_blocking=self.non_blocking)
-
-        else:
-
-            def select_and_clone(name, tensor):
-                if name in self._selected_step_keys:
-                    return tensor.clone()
-
         out = next_td.named_apply(
-            select_and_clone,
+            self.select_and_clone,
             nested_keys=True,
             filter_empty=True,
             device=device,

@@ -2203,20 +2197,10 @@ def tentative_update(val, other):
         selected_output_keys = self._selected_reset_keys_filt
         device = self.device
 
-        if self.shared_tensordict_parent.device != device and device is not None:
-
-            def select_and_clone(name, tensor):
-                if name in selected_output_keys:
-                    return tensor.to(device, non_blocking=self.non_blocking)
-
-        else:
-
-            def select_and_clone(name, tensor):
-                if name in selected_output_keys:
-                    return tensor.clone()
-
         out = self.shared_tensordict_parent.named_apply(
-            select_and_clone,
+            lambda *args: self.select_and_clone(
+                *args, selected_keys=selected_output_keys
+            ),
             nested_keys=True,
             filter_empty=True,
             device=device,
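
The refactor folds four local select_and_clone closures into one method that both filters on a key set and handles the device transfer, so the device-aware branch is no longer duplicated per call site. A self-contained sketch of the pattern outside the batched-env classes (key names are made up; it relies on named_apply with filter_empty=True dropping entries for which the callback returns None, as the original code does):

import torch
from tensordict import TensorDict


def select_and_clone(name, tensor, selected_keys, device=None, non_blocking=False):
    # keep only selected entries; move across devices when needed, otherwise clone
    if name in selected_keys:
        if device is not None and tensor.device != device:
            return tensor.to(device, non_blocking=non_blocking)
        return tensor.clone()
    return None  # dropped when filter_empty=True


td = TensorDict({"observation": torch.zeros(3), "scratch": torch.ones(1)}, [])
out = td.named_apply(
    lambda name, t: select_and_clone(name, t, selected_keys={"observation"}),
    nested_keys=True,
    filter_empty=True,
)
assert set(out.keys()) == {"observation"}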

torchrl/objectives/value/advantages.py (+12 -2)

@@ -1281,8 +1281,18 @@ def __init__(
             skip_existing=skip_existing,
             device=device,
         )
-        self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
-        self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device))
+        self.register_buffer(
+            "gamma",
+            gamma.to(self._device)
+            if isinstance(gamma, Tensor)
+            else torch.tensor(gamma, device=self._device),
+        )
+        self.register_buffer(
+            "lmbda",
+            lmbda.to(self._device)
+            if isinstance(lmbda, Tensor)
+            else torch.tensor(lmbda, device=self._device),
+        )
         self.average_gae = average_gae
         self.vectorized = vectorized
         self.time_dim = time_dim
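
This lets gamma and lmbda be passed either as Python floats or as pre-built tensors: calling torch.tensor() on an existing tensor copy-constructs it and emits a UserWarning, so tensors are now just moved to the target device. A standalone sketch of the same registration pattern (not the GAE class itself; class and argument names are illustrative):

import torch
from torch import Tensor, nn


class WithDiscount(nn.Module):
    def __init__(self, gamma, device="cpu"):
        super().__init__()
        # accept a float or an already-built tensor without re-wrapping it
        self.register_buffer(
            "gamma",
            gamma.to(device)
            if isinstance(gamma, Tensor)
            else torch.tensor(gamma, device=device),
        )


m_float = WithDiscount(0.99)                 # float -> wrapped with torch.tensor
m_tensor = WithDiscount(torch.tensor(0.95))  # tensor -> moved, no copy-construct warning
assert isinstance(m_tensor.gamma, Tensor)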
