Commit 9ba8cc6

[Doc] Add docstring for MCTSForest.extend
ghstack-source-id: dbef5e4
Pull Request resolved: pytorch#2795
1 parent 0f75ae2

File tree

1 file changed: +127 -0 lines changed


torchrl/data/map/tree.py (+127)
@@ -1040,6 +1040,133 @@ def _make_node_map(self, source, dest):
         self.max_size = self.data_map.max_size
 
     def extend(self, rollout, *, return_node: bool = False):
+        """Add a rollout to the forest.
+
+        Nodes are only added to a tree at points where rollouts diverge from
+        each other and at the endpoints of rollouts.
+
+        If there is no existing tree that matches the first steps of the
+        rollout, a new tree is added. Only one node is created, for the final
+        step.
+
+        If there is an existing tree that matches, the rollout is added to that
+        tree. If the rollout diverges from all other rollouts in the tree at
+        some step, a new node is created before the step where the rollouts
+        diverge, and a leaf node is created for the final step of the rollout.
+        If all of the rollout's steps match with a previously added rollout,
+        nothing changes. If the rollout matches up to a leaf node of a tree but
+        continues beyond it, that node is extended to the end of the rollout,
+        and no new nodes are created.
+
+        Args:
+            rollout (TensorDict): The rollout to add to the forest.
+            return_node (bool, optional): If True, the method returns the added
+                node. Default is ``False``.
+
+        Returns:
+            Tree: The node that was added to the forest. This is only
+            returned if ``return_node`` is True.
+
+        Examples:
+            >>> from torchrl.data import MCTSForest
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> forest = MCTSForest()
+            >>> r0 = TensorDict({
+            ...     'action': torch.tensor([1, 2, 3, 4, 5]),
+            ...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
+            ...     'observation': torch.tensor([  0, 123, 392, 989, 809])
+            ... }, [5])
+            >>> r1 = TensorDict({
+            ...     'action': torch.tensor([1, 2, 6, 7]),
+            ...     'next': {'observation': torch.tensor([123, 392, 235, 38])},
+            ...     'observation': torch.tensor([  0, 123, 392, 235])
+            ... }, [4])
+            >>> td_root = r0[0].exclude("next")
+            >>> forest.extend(r0)
+            >>> forest.extend(r1)
+            >>> tree = forest.get_tree(td_root)
+            >>> print(tree)
+            Tree(
+                count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                node_data=TensorDict(
+                    fields={
+                        observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([]),
+                    device=cpu,
+                    is_shared=False),
+                node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
+                rollout=TensorDict(
+                    fields={
+                        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                        next: TensorDict(
+                            fields={
+                                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                            batch_size=torch.Size([2]),
+                            device=cpu,
+                            is_shared=False),
+                        observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([2]),
+                    device=cpu,
+                    is_shared=False),
+                subtree=Tree(
+                    _parent=NonTensorStack(
+                        [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
+                    hash=NonTensorStack(
+                        [4341220243998689835, 6745467818783115365],
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    node_data=LazyStackedTensorDict(
+                        fields={
+                            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                        exclusive_fields={
+                        },
+                        batch_size=torch.Size([2]),
+                        device=cpu,
+                        is_shared=False,
+                        stack_dim=0),
+                    node_id=NonTensorStack(
+                        [1, 2],
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    rollout=LazyStackedTensorDict(
+                        fields={
+                            action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            next: LazyStackedTensorDict(
+                                fields={
+                                    observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                                exclusive_fields={
+                                },
+                                batch_size=torch.Size([2, -1]),
+                                device=cpu,
+                                is_shared=False,
+                                stack_dim=0),
+                            observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                        exclusive_fields={
+                        },
+                        batch_size=torch.Size([2, -1]),
+                        device=cpu,
+                        is_shared=False,
+                        stack_dim=0),
+                    wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                    index=None,
+                    subtree=None,
+                    specs=None,
+                    batch_size=torch.Size([2]),
+                    device=None,
+                    is_shared=False),
+                wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                hash=None,
+                _parent=None,
+                specs=None,
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+        """
         source, dest = (
             rollout.exclude("next").copy(),
             rollout.select("next", *self.action_keys).copy(),
