Skip to content

Commit a0e7a80

Browse files
committed
[Tutorial] Beam search with GPT models
ghstack-source-id: 62f96bf1965a65ca35485de6ee66260abe33f117
Pull Request resolved: #2623
1 parent 1cffffe commit a0e7a80

File tree

18 files changed

+773
-70
lines changed

18 files changed

+773
-70
lines changed

docs/requirements.txt

+4
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,7 @@ vmas
2828
onnxscript
2929
onnxruntime
3030
onnx
31+
plotly
32+
igraph
33+
transformers
34+
datasets
318 KB
Loading

docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ Intermediate
105105
tutorials/dqn_with_rnn
106106
tutorials/rb_tutorial
107107
tutorials/export
108+
tutorials/beam_search_with_gpt
108109

109110
Advanced
110111
--------

test/mocking_classes.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1776,14 +1776,18 @@ def __init__(self):
17761776
tensor=Unbounded(3),
17771777
non_tensor=NonTensor(shape=()),
17781778
)
1779+
self._saved_obs_spec = self.observation_spec.clone()
17791780
self.state_spec = Composite(
17801781
non_tensor=NonTensor(shape=()),
17811782
)
1783+
self._saved_state_spec = self.state_spec.clone()
17821784
self.reward_spec = Unbounded(1)
1785+
self._saved_full_reward_spec = self.full_reward_spec.clone()
17831786
self.action_spec = Unbounded(1)
1787+
self._saved_full_action_spec = self.full_action_spec.clone()
17841788

17851789
def _reset(self, tensordict):
1786-
data = self.observation_spec.zero()
1790+
data = self._saved_obs_spec.zero()
17871791
data.set_non_tensor("non_tensor", 0)
17881792
data.update(self.full_done_spec.zero())
17891793
return data
@@ -1792,10 +1796,10 @@ def _step(
17921796
self,
17931797
tensordict: TensorDictBase,
17941798
) -> TensorDictBase:
1795-
data = self.observation_spec.zero()
1799+
data = self._saved_obs_spec.zero()
17961800
data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1)
17971801
data.update(self.full_done_spec.zero())
1798-
data.update(self.full_reward_spec.zero())
1802+
data.update(self._saved_full_reward_spec.zero())
17991803
return data
18001804

18011805
def _set_seed(self, seed: Optional[int]):

test/test_env.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -3528,8 +3528,13 @@ def test_single_env_spec():
35283528
assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))
35293529

35303530

3531-
def test_auto_spec():
3532-
env = CountingEnv()
3531+
@pytest.mark.parametrize("env_type", [CountingEnv, EnvWithMetadata])
3532+
def test_auto_spec(env_type):
3533+
if env_type is EnvWithMetadata:
3534+
obs_vals = ["tensor", "non_tensor"]
3535+
else:
3536+
obs_vals = "observation"
3537+
env = env_type()
35333538
td = env.reset()
35343539

35353540
policy = lambda td, action_spec=env.full_action_spec.clone(): td.update(
@@ -3552,7 +3557,7 @@ def test_auto_spec():
35523557
shape=env.full_state_spec.shape, device=env.full_state_spec.device
35533558
)
35543559
env._action_keys = ["action"]
3555-
env.auto_specs_(policy, tensordict=td.copy())
3560+
env.auto_specs_(policy, tensordict=td.copy(), observation_key=obs_vals)
35563561
env.check_env_specs(tensordict=td.copy())
35573562

35583563

torchrl/_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,7 @@ def _can_be_pickled(obj):
829829
def _make_ordinal_device(device: torch.device):
830830
if device is None:
831831
return device
832+
device = torch.device(device)
832833
if device.type == "cuda" and device.index is None:
833834
return torch.device("cuda", index=torch.cuda.current_device())
834835
if device.type == "mps" and device.index is None:

torchrl/data/map/hash.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
7575
class SipHash(Module):
7676
"""A Module to Compute SipHash values for given tensors.
7777
78-
A hash function module based on SipHash implementation in python.
78+
A hash function module based on SipHash implementation in python. Input tensors should have shape ``[batch_size, num_features]``
79+
and the output shape will be ``[batch_size]``.
7980
8081
Args:
8182
as_tensor (bool, optional): if ``True``, the bytes will be turned into integers

torchrl/data/map/tdstorage.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def from_tensordict_pair(
177177
collate_fn: Callable[[Any], Any] | None = None,
178178
write_fn: Callable[[Any, Any], Any] | None = None,
179179
consolidated: bool | None = None,
180-
):
180+
) -> TensorDictMap:
181181
"""Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
182182
183183
Args:
@@ -308,7 +308,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
308308
if not self._has_lazy_out_keys():
309309
# TODO: make this work with pytrees and avoid calling select if keys match
310310
value = value.select(*self.out_keys, strict=False)
311+
item, value = self._maybe_add_batch(item, value)
312+
index = self._to_index(item, extend=True)
313+
if index.unique().numel() < index.numel():
314+
# If multiple values point to the same place in the storage, we cannot process them by batch
315+
# There could be a better way to deal with this, using unique ids.
316+
vals = []
317+
for it, val in zip(item.split(1), value.split(1)):
318+
self[it] = val
319+
vals.append(val)
320+
# __setitem__ may affect the content of the input data
321+
value.update(TensorDictBase.lazy_stack(vals))
322+
return
311323
if self.write_fn is not None:
324+
# We use this block in the following context: the value written in the storage is already present,
325+
# but it needs to be updated.
326+
# We first check if the value is already there using `contains`. If so, we pass the new value and the
327+
# previous one to write_fn. The values that are not present are passed alone.
312328
if len(self):
313329
modifiable = self.contains(item)
314330
if modifiable.any():
@@ -322,8 +338,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
322338
value = self.write_fn(value)
323339
else:
324340
value = self.write_fn(value)
325-
item, value = self._maybe_add_batch(item, value)
326-
index = self._to_index(item, extend=True)
327341
self.storage.set(index, value)
328342

329343
def __len__(self):

0 commit comments

Comments
 (0)