# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Demonstrates how to use TorchRL's RayReplayBuffer to store and sample data across nodes,
in the context of Large Language Models (LLMs).

This script showcases a simple producer-consumer setup where one node generates trajectories
from a dialogue dataset and stores them in a shared replay buffer, while another node samples
data from this buffer.

The `Trajectory` class represents a single trajectory, containing the prompt, the response,
the tokens, and other relevant information. The `producer` function generates these
trajectories and extends the replay buffer, while the `consumer` function samples data from
the buffer.

The script handles tensors with ragged dimensions: they are stored in lazy stacks of
tensordicts (or, more specifically, tensorclasses). Getting the string fields returns a list,
whereas getting the ragged tensors raises an error unless the output format is specified
(see the examples below).
"""

import time
from functools import partial

import ray
import torch
from tensordict import lazy_stack, TensorClass

from torchrl._utils import logger as torchrl_logger
from torchrl.data import LazyStackStorage, RayReplayBuffer


class Trajectory(TensorClass["nocast"]):
    # A string or list of strings with the prompts
    prompt: str
    # A string or list of strings with the responses
    response: str
    # A ragged tensor with the prompt tokens
    tokens: torch.Tensor
    # A ragged tensor with the response tokens
    tokens_response: torch.Tensor
    # A ragged tensor with per-token log-probs (same shape as tokens_response)
    logits: torch.Tensor | None = None
    # A ragged tensor with per-token rewards
    rewards: torch.Tensor | None = None
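
# A minimal, illustrative sketch (not executed) of how a batch of Trajectory instances
# behaves once stacked with lazy_stack (as done in `producer` below): string fields come
# back as plain Python lists, whereas ragged tensor fields must be requested in an explicit
# format. Shapes and values below are placeholders.
#
#   batch = lazy_stack([traj_0, traj_1])
#   batch.prompt                                   # list of str, one entry per trajectory
#   batch.get("tokens", as_list=True)              # list of 1D tensors of varying lengths
#   batch.get("tokens", as_padded_tensor=True)     # dense [2, max_len] tensor (padded)
#   batch.get("tokens", as_nested_tensor=True, layout=torch.jagged)  # jagged nested tensor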


@ray.remote(num_cpus=1)
def producer(rb):
    from datasets import load_dataset

    # Get a tokenizer
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
    dataset = load_dataset("daily_dialog", trust_remote_code=True)["train"]
    data = []
    for i, dialog in enumerate(dataset):
        # Assuming each dialog is a list of utterances, pair each utterance (prompt)
        # with the next one (response)
        for j in range(len(dialog["dialog"]) - 1):
            prompt = dialog["dialog"][j]
            response = dialog["dialog"][j + 1]
            tokens = tokenizer.encode(prompt, return_tensors="pt")
            tokens_response = tokenizer.encode(response, return_tensors="pt")
            # Random placeholder log-probs with the same shape as the response tokens
            logits = torch.randn_like(tokens_response, dtype=torch.float16)
            data.append(
                Trajectory(
                    prompt=prompt,
                    response=response,
                    tokens=tokens.squeeze(),
                    tokens_response=tokens_response.squeeze(),
                    logits=logits.squeeze(),
                )
            )
        if i == 256:
            break
    # Stack the trajectories without padding or copying, then send them to the buffer
    data = lazy_stack(data)
    rb.extend(data)
    torchrl_logger.info(f"Extended with {data=}")
    torchrl_logger.info(f"State of buffer at exit time: {rb}")


@ray.remote(num_cpus=1)
def consumer(rb):
    # Wait until the producer has written something to the buffer
    while not rb.write_count:
        torchrl_logger.info("Consumer waiting for data...")
        time.sleep(1)
    for _ in range(1):
        samples = rb.sample()
        torchrl_logger.info(f"Sampling data: {samples}")
        time.sleep(1)

    # We can also sample fewer elements by passing the batch size to the sample method
    samples = rb.sample(4)
    # To get the strings, we can simply access the attributes: a list of strings is returned
    prompt = samples.prompt
    assert len(prompt) == 4
    assert isinstance(prompt, list)
    response = samples.response
    assert len(response) == 4
    assert isinstance(response, list)
    # For tokens / tokens_response / logits / rewards, we can choose between padded tensors,
    # nested tensors, or lists
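    # As a rough rule of thumb (general guidance, not specific to this buffer): padded
    # tensors are convenient for dense batched operations at the cost of extra memory,
    # nested (jagged) tensors avoid padding but support fewer operations, and lists are
    # the simplest option when iterating over samples one at a time.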
    tokens_padded = samples.get(
        "tokens", as_padded_tensor=True, padding_value=0, padding_side="right"
    )
    tokens_nested = samples.get("tokens", as_nested_tensor=True, layout=torch.jagged)
    tokens_list = samples.get("tokens", as_list=True)
    torchrl_logger.info(f"{tokens_padded=}")
    torchrl_logger.info(f"{tokens_nested=}")
    torchrl_logger.info(f"{tokens_list=}")
    time.sleep(1)


if __name__ == "__main__":
    # The replay buffer runs as its own Ray actor
    rb = RayReplayBuffer(storage=partial(LazyStackStorage, 1_000_000), batch_size=128)
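    # Note: the storage is passed as a constructor (a partial) rather than an instance,
    # presumably so that it is instantiated inside the buffer's Ray actor rather than on
    # the driver process.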
    # Pass the buffer handle to the producer
    producer_handler = producer.remote(rb)
    # Pass the buffer handle to the consumer
    consumer_handler = consumer.remote(rb)
    ray.get([producer_handler, consumer_handler])  # Wait for both tasks to complete
    ray.shutdown()