Skip to content

Commit 59e8545

Browse files
committed
[Feature] TensorDictPrimer with single default_value callable
ghstack-source-id: a9a677f24fc1e6a47312d0a96ab60daae543ff78 Pull Request resolved: #2732
1 parent 8c9dc05 commit 59e8545

File tree

3 files changed

+141
-75
lines changed

3 files changed

+141
-75
lines changed

Diff for: torchrl/envs/custom/pendulum.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,20 @@ def _reset(self, tensordict):
269269
batch_size = (
270270
tensordict.batch_size if tensordict is not None else self.batch_size
271271
)
272-
if tensordict is None or tensordict.is_empty():
272+
if tensordict is None or "params" not in tensordict:
273273
# if no ``tensordict`` is passed, we generate a single set of hyperparameters
274274
# Otherwise, we assume that the input ``tensordict`` contains all the relevant
275275
# parameters to get started.
276276
tensordict = self.gen_params(batch_size=batch_size, device=self.device)
277+
elif "th" in tensordict and "thdot" in tensordict:
278+
# we can hard-reset the env too
279+
return tensordict
280+
out = self._reset_random_data(
281+
tensordict.shape, batch_size, tensordict["params"]
282+
)
283+
return out
284+
285+
def _reset_random_data(self, shape, batch_size, params):
277286

278287
high_th = torch.tensor(self.DEFAULT_X, device=self.device)
279288
high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device)
@@ -284,20 +293,20 @@ def _reset(self, tensordict):
284293
# of simulators run simultaneously. In other contexts, the initial
285294
# random state's shape will depend upon the environment batch-size instead.
286295
th = (
287-
torch.rand(tensordict.shape, generator=self.rng, device=self.device)
296+
torch.rand(shape, generator=self.rng, device=self.device)
288297
* (high_th - low_th)
289298
+ low_th
290299
)
291300
thdot = (
292-
torch.rand(tensordict.shape, generator=self.rng, device=self.device)
301+
torch.rand(shape, generator=self.rng, device=self.device)
293302
* (high_thdot - low_thdot)
294303
+ low_thdot
295304
)
296305
out = TensorDict(
297306
{
298307
"th": th,
299308
"thdot": thdot,
300-
"params": tensordict["params"],
309+
"params": params,
301310
},
302311
batch_size=batch_size,
303312
)

Diff for: torchrl/envs/transforms/rlhf.py

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
57
from copy import copy, deepcopy
68

79
import torch

0 commit comments

Comments
 (0)