
Commit 4745d5d

[Feature] Allow using lists of tensors in vllm instead of padded tensors
ghstack-source-id: 46615f1
Pull Request resolved: #2861
1 parent e1d3fd4 · commit 4745d5d
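To illustrate what the title means in practice: a batch of variable-length token sequences can either be padded into one rectangular tensor or kept as a plain Python list of 1D tensors. The sketch below is illustrative only (plain PyTorch, hypothetical data); it is not code from this commit and does not show the torchrl/vLLM internals being changed.

# Illustration only: padded batch vs. list of tensors for ragged sequences.
import torch
from torch.nn.utils.rnn import pad_sequence

# Three token sequences of different lengths (hypothetical data).
seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]

# Padded representation: one rectangular tensor, with filler values in the short rows.
padded = pad_sequence(seqs, batch_first=True, padding_value=0)  # shape [3, 3]

# List-of-tensors representation: no padding, each element keeps its true length.
as_list = list(seqs)  # lengths 3, 2, 1

print(padded.shape, [t.numel() for t in as_list])

The list form avoids allocating and later masking out padding slots, at the cost of not being a single contiguous tensor.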

File tree

5 files changed: +471 -4 lines changed

@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import ArgumentParser

import torch
from datasets import load_dataset
from grpo_utils import (
    HF2vLLMLocalWeightUpdater,
    PrepareQuestion,
    ShapedCorrectnessReward,
)
from tensordict import TensorDict
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torchrl.collectors import SyncDataCollector
from torchrl.data import (
    LazyStackStorage,
    ReplayBuffer,
    SamplerWithoutReplacement,
)
from torchrl.envs import KLRewardTransform, LLMEnv, StepCounter
from torchrl.modules import TransformersWrapper, vLLMWrapper
from torchrl.objectives import ClipPPOLoss
from torchrl.record import WandbLogger
from transformers import GPT2LMHeadModel
from vllm import LLM

parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="gsm8k")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--repeats", type=int, default=10)
parser.add_argument("--steps_per_batch", type=int, default=16)
parser.add_argument("--optim_batch_size", type=int, default=4)
# parser.add_argument("--model_name", type=str, default="gpt2")
parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-3B")


def compute_mc_advantage(trajectories):
    # Get the answer strings (trajectories generated from the same question share the same answer)
    answer = trajectories["answer"]
    # Identify indices where the answers match
    answer_ids = tree_map(lambda string: hash(string), answer)
    answer_ids = torch.tensor(answer_ids)
    unique_qs = answer_ids.view(-1).unique()
    trajectories["advantage"] = trajectories["next", "reward"] * 0
    for u in unique_qs:
        idx = answer_ids == u
        rewards = trajectories[idx]["next", "reward"]
        rewards = (rewards - rewards.mean()) / rewards.std().clamp(min=1e-4)
        trajectories.set_at_("advantage", rewards, idx)
    return trajectories


if __name__ == "__main__":
    args = parser.parse_args()
    # Create env instance:
    # - Load the gsm8k dataset
    dataset = load_dataset(args.dataset, "main")
    train_dataset = dataset["train"]

    def collate_fn(batch):
        batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
        batch.rename_key_("question", "text")
        return batch

    # LLM
    # inference_model = GPT2LMHeadModel(GPT2Config())
    inference_model = LLM(args.model_name)
    tokenizer = inference_model.get_tokenizer()
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Env
    dataloader = DataLoader(  # noqa: TOR401
        train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
    )
    env = LLMEnv.from_dataloader(
        dataloader=dataloader,
        # tokenizer=tokenizer,
        str2str=True,
        batch_size=(args.batch_size * args.repeats,),
        repeats=args.repeats,
    )
    env.insert_transform(0, PrepareQuestion())

    # Finally, we want the env to stop after the first step
    env.append_transform(StepCounter(max_steps=1))

    policy = vLLMWrapper(
        inference_model,
        tokenizer=tokenizer,
        from_text=True,
        generate=True,
        # vLLM log-probs are a bit unreliable; we could use something else
        return_log_probs=True,
    )

    # Reward transform
    env.append_transform(ShapedCorrectnessReward(tokenizer=tokenizer))

    # Ref model
    ref_model = GPT2LMHeadModel.from_pretrained(args.model_name).eval()
    TensorDict.from_module(ref_model).data.to_module(ref_model)
    ref_model = TransformersWrapper(
        ref_model,
        tokenizer=tokenizer,
        from_text=False,
        generate=False,
        return_log_probs=True,
    )
    env.append_transform(
        KLRewardTransform(actor=ref_model, coef=0.1, log_prob_key="log_probs")
    )

    # Replay buffer
    rb = ReplayBuffer(
        storage=LazyStackStorage(args.steps_per_batch),
        sampler=SamplerWithoutReplacement(),
        batch_size=args.optim_batch_size,
    )

    # Collector
    train_model = GPT2LMHeadModel.from_pretrained(args.model_name).eval()
    collector = SyncDataCollector(
        env,
        policy,
        frames_per_batch=args.steps_per_batch,
        total_frames=1_000_000,
        local_weight_updater=HF2vLLMLocalWeightUpdater(
            hf_model=train_model, vllm_model=inference_model
        ),
        use_buffers=False,
    )

    # Loss module
    policy_training = TransformersWrapper(
        train_model,
        tokenizer=tokenizer,
        # We have the tokens, let's just use them
        from_text=False,
        generate=False,
        return_log_probs=True,
    )
    loss_fn = ClipPPOLoss(
        actor_network=policy_training,
        critic_network=None,
        critic_coef=0.0,
        functional=False,
    )
    loss_fn.set_keys(sample_log_prob="log_probs")
    loss_fn._set_in_keys()
    optim = torch.optim.Adam(loss_fn.parameters())

    # loss_fn = ReinforceLoss(
    #     actor_network=policy,
    #     critic_network=None,
    #     critic_coef=0.0,
    # )

    logger = WandbLogger(exp_name=args.model_name)
    for i, trajs in enumerate(collector):
        print("Collected batch", i)
        trajs = trajs.reshape(-1)
        trajs = compute_mc_advantage(trajs)
        rb.extend(trajs)
        # Logging
        reward = torch.cat(rb[:].get(("next", "reward"), as_list=True)).mean()
        logger.log_scalar("reward", reward)
        for _ in range(args.epochs):
            for batch in rb:
                loss = loss_fn(batch)
                loss_val = loss.mean(reduce=True)
                loss_val.backward()
                optim.step()
                optim.zero_grad()
        collector.update_policy_weights_()
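As a reading aid for compute_mc_advantage above: it groups trajectories that came from the same question (identified via the hashed answer string) and standardizes the rewards within each group. The toy sketch below mirrors only that grouping-and-normalization math in plain PyTorch; the names and data are hypothetical and the TensorDict plumbing is omitted.

# Toy sketch of the group-normalized (Monte-Carlo) advantage; hypothetical data.
import torch

rewards = torch.tensor([1.0, 0.0, 1.0, 0.5, 0.0, 0.0])  # one reward per trajectory
group_ids = torch.tensor([0, 0, 0, 1, 1, 1])  # which prompt each trajectory came from

advantage = torch.zeros_like(rewards)
for g in group_ids.unique():
    idx = group_ids == g
    r = rewards[idx]
    # Standardize within the group; clamp the std to avoid dividing by ~0.
    advantage[idx] = (r - r.mean()) / r.std().clamp(min=1e-4)

print(advantage)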
