@@ -1040,6 +1040,133 @@ def _make_node_map(self, source, dest):
         self.max_size = self.data_map.max_size
 
     def extend(self, rollout, *, return_node: bool = False):
+        """Add a rollout to the forest.
+
+        Nodes are only added to a tree at points where rollouts diverge from
+        each other and at the endpoints of rollouts.
+
+        If there is no existing tree that matches the first steps of the
+        rollout, a new tree is added. Only one node is created, for the final
+        step.
+
+        If there is an existing tree that matches, the rollout is added to that
+        tree. If the rollout diverges from all other rollouts in the tree at
+        some step, a new node is created before the step where the rollouts
+        diverge, and a leaf node is created for the final step of the rollout.
+        If all of the rollout's steps match with a previously added rollout,
+        nothing changes. If the rollout matches up to a leaf node of a tree but
+        continues beyond it, that node is extended to the end of the rollout,
+        and no new nodes are created.
+
+        Args:
+            rollout (TensorDict): The rollout to add to the forest.
+            return_node (bool, optional): If True, the method returns the added
+                node. Default is ``False``.
+
+        Returns:
+            Tree: The node that was added to the forest. This is only
+            returned if ``return_node`` is True.
+
+        Examples:
+            >>> from torchrl.data import MCTSForest
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> forest = MCTSForest()
+            >>> r0 = TensorDict({
+            ...     'action': torch.tensor([1, 2, 3, 4, 5]),
+            ...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
+            ...     'observation': torch.tensor([ 0, 123, 392, 989, 809])
+            ... }, [5])
+            >>> r1 = TensorDict({
+            ...     'action': torch.tensor([1, 2, 6, 7]),
+            ...     'next': {'observation': torch.tensor([123, 392, 235, 38])},
+            ...     'observation': torch.tensor([ 0, 123, 392, 235])
+            ... }, [4])
+            >>> td_root = r0[0].exclude("next")
+            >>> forest.extend(r0)
+            >>> forest.extend(r1)
+            >>> tree = forest.get_tree(td_root)
+            >>> print(tree)
+            Tree(
+                count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                node_data=TensorDict(
+                    fields={
+                        observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([]),
+                    device=cpu,
+                    is_shared=False),
+                node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
+                rollout=TensorDict(
+                    fields={
+                        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                        next: TensorDict(
+                            fields={
+                                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                            batch_size=torch.Size([2]),
+                            device=cpu,
+                            is_shared=False),
+                        observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([2]),
+                    device=cpu,
+                    is_shared=False),
+                subtree=Tree(
+                    _parent=NonTensorStack(
+                        [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
+                    hash=NonTensorStack(
+                        [4341220243998689835, 6745467818783115365],
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    node_data=LazyStackedTensorDict(
+                        fields={
+                            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                        exclusive_fields={
+                        },
+                        batch_size=torch.Size([2]),
+                        device=cpu,
+                        is_shared=False,
+                        stack_dim=0),
+                    node_id=NonTensorStack(
+                        [1, 2],
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    rollout=LazyStackedTensorDict(
+                        fields={
+                            action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            next: LazyStackedTensorDict(
+                                fields={
+                                    observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                                exclusive_fields={
+                                },
+                                batch_size=torch.Size([2, -1]),
+                                device=cpu,
+                                is_shared=False,
+                                stack_dim=0),
+                            observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                        exclusive_fields={
+                        },
+                        batch_size=torch.Size([2, -1]),
+                        device=cpu,
+                        is_shared=False,
+                        stack_dim=0),
+                    wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                    index=None,
+                    subtree=None,
+                    specs=None,
+                    batch_size=torch.Size([2]),
+                    device=None,
+                    is_shared=False),
+                wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                hash=None,
+                _parent=None,
+                specs=None,
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+        """
         source, dest = (
             rollout.exclude("next").copy(),
             rollout.select("next", *self.action_keys).copy(),
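Beyond the doctest above, here is a minimal sketch (an editor's illustration, not part of the diff) of how the divergence behavior and the ``return_node`` flag described in the docstring might be exercised. The assertions and the ``leaf`` variable are assumptions drawn from the docstring text and the printed ``Tree`` repr, not from the patch itself.

import torch
from tensordict import TensorDict
from torchrl.data import MCTSForest

forest = MCTSForest()
# r0 and r1 share their first two steps (observations 0 -> 123 -> 392),
# then diverge at the third action (3 vs. 6).
r0 = TensorDict({
    "action": torch.tensor([1, 2, 3, 4, 5]),
    "next": {"observation": torch.tensor([123, 392, 989, 809, 847])},
    "observation": torch.tensor([0, 123, 392, 989, 809]),
}, [5])
r1 = TensorDict({
    "action": torch.tensor([1, 2, 6, 7]),
    "next": {"observation": torch.tensor([123, 392, 235, 38])},
    "observation": torch.tensor([0, 123, 392, 235]),
}, [4])

forest.extend(r0)
# return_node=True hands back the Tree node that was added for this rollout
# (per the Args/Returns sections of the docstring).
leaf = forest.extend(r1, return_node=True)

tree = forest.get_tree(r0[0].exclude("next"))
# Per the repr in the docstring: the root branch keeps only the two shared
# steps, and the two diverging tails live in a stacked subtree of batch size 2.
assert tree.rollout.batch_size == torch.Size([2])
assert tree.subtree.batch_size == torch.Size([2])
print(leaf.node_id)  # id of the newly created leaf node (illustrative)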