Commit a31dca3

[Example] RayReplayBuffer usage
ghstack-source-id: ab39a46
Pull-Request-resolved: #2949
1 parent cb06ea3 commit a31dca3

File tree

3 files changed: +125 -1 lines changed

.gitignore (+1)

@@ -9,6 +9,7 @@ __pycache__/
 # Distribution / packaging
 .Python
 build/
+dump/
 develop-eggs/
 dist/
 downloads/
New example script (+123)
@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Demonstrates how to use torchrl's RayReplayBuffer to store and sample data across nodes, specifically in the context of Large Language Models (LLMs).

This script showcases a simple producer-consumer setup where one node generates trajectories
from a dialogue dataset and stores them in a shared replay buffer, while another node samples
data from this buffer.

The `Trajectory` class represents a single trajectory, containing prompt, response, tokens,
and other relevant information. The `producer` function generates these trajectories and
extends the replay buffer, while the `consumer` function samples data from the buffer.

The script handles tensors with ragged dimensions. They are stored in lazy stacks of tensordicts (or more specifically,
tensorclasses). Getting the strings returns a list, whereas getting the tensors will raise an error, unless the
format is specified (see examples).

"""

import time
from functools import partial

import ray
import torch
from tensordict import lazy_stack, TensorClass

from torchrl._utils import logger as torchrl_logger
from torchrl.data import LazyStackStorage, RayReplayBuffer


class Trajectory(TensorClass["nocast"]):
    # A string or list of strings with the prompts
    prompt: str
    # A string or list of strings with the responses
    response: str
    # A ragged tensor with tokens
    tokens: torch.Tensor
    # A ragged tensor with tokens (responses)
    tokens_response: torch.Tensor
    # A ragged tensor with log-probs (same size as tokens_response)
    logits: torch.Tensor | None = None
    # A ragged tensor with per-token reward
    rewards: torch.Tensor | None = None


@ray.remote(num_cpus=1)
def producer(rb):
    from datasets import load_dataset

    # Get some tokenizer
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
    dataset = load_dataset("daily_dialog", trust_remote_code=True)["train"]
    data = []
    for i, dialog in enumerate(dataset):
        # Assuming each dialog is a list of utterances
        for j in range(len(dialog["dialog"]) - 1):
            prompt = dialog["dialog"][j]
            response = dialog["dialog"][j + 1]
            tokens = tokenizer.encode(prompt, return_tensors="pt")
            tokens_response = tokenizer.encode(response, return_tensors="pt")
            logits = torch.randn_like(tokens_response, dtype=torch.float16)
            data.append(
                Trajectory(
                    prompt=prompt,
                    response=response,
                    tokens=tokens.squeeze(),
                    tokens_response=tokens_response.squeeze(),
                    logits=logits,
                )
            )
        if i == 256:
            break
    data = lazy_stack(data)
    rb.extend(data)
    torchrl_logger.info(f"Extended with {data=}")
    torchrl_logger.info(f"State of buffer at exit time: {rb}")


@ray.remote(num_cpus=1)
def consumer(rb):
    while not rb.write_count:
        torchrl_logger.info("Consumer waiting for data...")
        time.sleep(1)
    for _ in range(1):
        samples = rb.sample()
        torchrl_logger.info(f"Sampling data: {samples}")
        time.sleep(1)

    # We can also sample fewer elements by passing the batch-size to the sample method
    samples = rb.sample(4)
    # To get the strings, we can use __getitem__
    prompt = samples.prompt
    assert len(prompt) == 4
    assert isinstance(prompt, list)
    response = samples.response
    assert len(response) == 4
    assert isinstance(response, list)
    # For tokens / tokens_response / logits / rewards, we can choose between nested tensors, lists or padded tensors
    tokens_padded = samples.get(
        "tokens", as_padded_tensor=True, padding_value=0, padding_side="right"
    )
    tokens_nested = samples.get("tokens", as_nested_tensor=True, layout=torch.jagged)
    tokens_list = samples.get("tokens", as_list=True)
    torchrl_logger.info(f"{tokens_padded=}")
    torchrl_logger.info(f"{tokens_nested=}")
    torchrl_logger.info(f"{tokens_list=}")
    time.sleep(1)


if __name__ == "__main__":
    # The RB is its own ray worker
    rb = RayReplayBuffer(storage=partial(LazyStackStorage, 1_000_000), batch_size=128)
    # Pass handler to producer
    producer_handler = producer.remote(rb)
    # Pass handler to consumer
    consumer_handler = consumer.remote(rb)
    ray.get([producer_handler, consumer_handler])  # Wait for both tasks to complete
    ray.shutdown()
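
Running the script above requires Ray plus the datasets and transformers packages. For readers who just want to see the ragged-data round trip, here is a minimal single-process sketch. It is illustrative only: it assumes a local torchrl ReplayBuffer can be paired with LazyStackStorage the same way RayReplayBuffer is above, and it reuses the get(..., as_padded_tensor=...)/as_list accessors shown in the consumer; the Trajectory fields and shapes are made up for the demo.

# Minimal local sketch (no Ray, no datasets/transformers): a trimmed-down
# Trajectory class, LazyStackStorage, and an in-process ReplayBuffer.
import torch
from tensordict import lazy_stack, TensorClass
from torchrl.data import LazyStackStorage, ReplayBuffer


class Trajectory(TensorClass["nocast"]):
    prompt: str
    response: str
    tokens: torch.Tensor
    tokens_response: torch.Tensor


# Build a few trajectories with ragged token lengths
data = lazy_stack(
    [
        Trajectory(
            prompt=f"question {i}",
            response=f"answer {i}",
            tokens=torch.randint(0, 100, (5 + i,)),
            tokens_response=torch.randint(0, 100, (3 + i,)),
        )
        for i in range(8)
    ]
)

rb = ReplayBuffer(storage=LazyStackStorage(100), batch_size=4)
rb.extend(data)

samples = rb.sample()
# Strings come back as plain Python lists
assert isinstance(samples.prompt, list)
# Ragged tensors must be requested in an explicit format
padded = samples.get(
    "tokens", as_padded_tensor=True, padding_value=0, padding_side="right"
)
as_list = samples.get("tokens", as_list=True)
print(padded.shape, [t.shape for t in as_list])

The same accessors apply to tokens_response, logits and rewards in the full example.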

torchrl/collectors/distributed/ray.py

+1 -1

@@ -60,7 +60,7 @@

 DEFAULT_REMOTE_CLASS_CONFIG = {
     "num_cpus": 1,
-    "num_gpus": 0.2,
+    "num_gpus": 0.2 if torch.cuda.is_available() else None,
     "memory": 2 * 1024**3,
 }
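
The one-line change above makes the default fractional GPU request conditional: on a CPU-only machine, asking Ray for num_gpus=0.2 would leave remote workers waiting for a GPU that never appears. Below is a small, self-contained sketch of the same pattern applied to an arbitrary Ray task; the probe function and remote_kwargs names are illustrative and not part of torchrl.

import ray
import torch

# Mirror of the patched default: only request a GPU fraction when CUDA is
# available, so the task remains schedulable on CPU-only machines
# (num_gpus=None is Ray's default, i.e. no GPU requirement).
remote_kwargs = {
    "num_cpus": 1,
    "num_gpus": 0.2 if torch.cuda.is_available() else None,
    "memory": 2 * 1024**3,
}


@ray.remote(**remote_kwargs)
def probe():
    return torch.cuda.is_available()


if __name__ == "__main__":
    ray.init()
    print(ray.get(probe.remote()))
    ray.shutdown()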
