Commit b3224ee

[Feature] transformers policy
ghstack-source-id: 5a6b6e8eccafb7c89ece8557223ffc678711b449
Pull Request resolved: #2825
1 parent 45ad106 commit b3224ee

File tree

1 file changed: +186 -0 lines changed


torchrl/modules/llm/transformers.py

@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# TODO: lazy imports

from transformers import AutoModelForCausalLM, AutoTokenizer
from tensordict.nn import (
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
    TensorDictModuleBase,
    WrapModule,
)
from tensordict import NestedKey, NonTensorData, TensorDict, TensorDictBase
import transformers
import torch


def _maybe_clear_device(td):
    # Stash the incoming device as non-tensor metadata and clear it, so the
    # tensordict can travel to the model device and back.
    if td.device is None:
        return td
    return td.set("_source_device", NonTensorData(td.device)).clear_device_()


def _maybe_set_device(td):
    # Restore the device stashed by _maybe_clear_device, if any.
    device = td.pop("_source_device", None)
    if device is None:
        return td
    elif isinstance(device, NonTensorData):
        device: torch.device = device.data
    return td.to(device)


def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
    # TODO: how do we avoid getting these?
    del td["tokens_out", "past_key_values"]
    # generate() returns one score tensor per generated step; stack them along
    # the sequence dimension.
    scores = dict(td["tokens_out", "scores"].items())
    scores = torch.stack(
        [scores[str(k)] for k in range(len(scores))], 1
    )  # shape (B, seq-len, vocab_size)
    logits = scores - scores.logsumexp(dim=-1, keepdim=True)
    td["logits"] = scores
    del td["tokens_out", "scores"]
    seq_len = scores.shape[1]
    tokens = td["tokens_out", "sequences"][..., -seq_len:]  # shape (B, seq-len)
    log_probs = logits.gather(-1, tokens.unsqueeze(-1))
    td["log_probs"] = log_probs
    return td


def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
    # TODO: how do we avoid getting these?
    del td["forward", "past_key_values"]
    # A plain forward pass already yields per-token logits; normalize them and
    # gather the log-probability of each input token.
    scores = td["forward", "logits"]
    logits = scores - scores.logsumexp(dim=-1, keepdim=True)
    td["logits"] = scores
    del td["forward"]
    seq_len = scores.shape[1]
    tokens = td["tokens_in", "input_ids"]
    log_probs = logits.gather(-1, tokens.unsqueeze(-1))
    td["log_probs"] = log_probs
    return td


def from_hf_transformers(
    model: transformers.modeling_utils.PreTrainedModel,
    *,
    generate: bool = True,
    return_log_probs: bool = True,
    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None,
    from_text: bool = False,
    device: torch.device | None = None,
    text_key: NestedKey = "text",
    input_key: NestedKey = "input_ids",
    kwargs: dict | None = None,
    tokenizer_kwargs: dict | None = None,
) -> TensorDictModuleBase:

    # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks

    module_dict = {}
    if device:
        module_dict["clear_device"] = _maybe_clear_device
    if from_text:
        # text -> tokens
        if not tokenizer_kwargs:
            tokenizer_kwargs = {}
        if not tokenizer_kwargs.setdefault("return_attention_mask", True):
            raise RuntimeError
        if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt":
            raise RuntimeError
        # TODO: add other paddings
        if tokenizer_kwargs.setdefault("padding", True) not in (True,):
            raise RuntimeError
        if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
            raise RuntimeError

        module_dict["encode"] = Mod(
            tokenizer,
            in_keys=[text_key],
            out_keys=["tokens_in"],
            method_kwargs=tokenizer_kwargs,
            strict=True,
        )
    if device:
        module_dict["to_dest_device"] = Mod(
            lambda tensor: tensor.to(device),
            in_keys=["tokens_in"],
            out_keys=["tokens_in"],
            strict=True,
        )

    if generate:
        # Generation branch: call model.generate and, if requested, derive
        # log-probs from the returned scores.
        if not kwargs:
            kwargs = {}
        if return_log_probs:
            if not kwargs.setdefault("output_scores", True):
                raise RuntimeError
            if not kwargs.setdefault("return_dict_in_generate", True):
                raise RuntimeError
        if (
            kwargs.setdefault("tokenizer", tokenizer) is not tokenizer
            and tokenizer is not None
        ):
            raise RuntimeError

        module_dict["generate"] = Mod(
            model,
            method="generate",
            method_kwargs=kwargs,
            in_keys={
                "input_ids": ("tokens_in", "input_ids"),
                "attention_mask": ("tokens_in", "attention_mask"),
            },
            out_keys=["tokens_out"],
            out_to_in_map=True,
            strict=True,
        )
        if return_log_probs:
            module_dict["extract_log_probs"] = WrapModule(
                log_probs_from_scores,
                in_keys=[("tokens_out", "sequences"), ("tokens_out", "scores")],
                out_keys=["logits", "log_probs"],
            )
        if from_text:
            module_dict["decode"] = Mod(
                tokenizer.batch_decode,
                in_keys=[("tokens_out", "sequences")],
                out_keys=["action"],
                strict=True,
            )

    else:
        # Log-prob-only branch: plain forward pass over the provided tokens.
        if not kwargs:
            kwargs = {}
        if not kwargs.setdefault("return_dict", True):
            raise RuntimeError
        if not return_log_probs:
            raise RuntimeError
        module_dict["get_logprobs"] = Mod(
            model,
            method_kwargs=kwargs,
            in_keys={
                "input_ids": ("tokens_in", "input_ids"),
                "attention_mask": ("tokens_in", "attention_mask"),
            },
            out_keys=["forward"],
            out_to_in_map=True,
            strict=True,
        )
        module_dict["extract_log_probs"] = WrapModule(
            log_probs_from_logits,
            in_keys=[("tokens_in", "input_ids"), ("forward", "logits")],
            out_keys=["logits", "log_probs"],
        )
    if device:
        module_dict["to_source_device"] = _maybe_set_device
    return Seq(module_dict)


if __name__ == "__main__":
    max_seq_length = 50000
    model_name = "Qwen/Qwen2.5-7B-Instruct"

    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    tokenizer.padding_side = "left"

    m = from_hf_transformers(
        model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=True
    )
    td = m(TensorDict(text="a text"))

    m = from_hf_transformers(
        model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=False
    )
    td = m(TensorDict(text="a text"))
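
For reference, a minimal usage sketch (not part of the commit) showing how the factory above can be exercised end to end without a large checkpoint. It assumes a small model such as gpt2 and that from_hf_transformers defined above is in scope (e.g. the sketch is run inside the new file); gpt2 ships without a pad token, so the EOS token is reused for padding.

# Hypothetical sketch, not part of the commit: drive the wrapped policy on a
# single prompt with a small model and inspect the entries it produces.
from tensordict import TensorDict
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
tokenizer.padding_side = "left"

# from_hf_transformers is assumed to be importable/in scope here.
policy = from_hf_transformers(
    model, tokenizer=tokenizer, from_text=True, generate=True
)
out = policy(TensorDict(text="a text"))

# Expected entries, following the module_dict assembled above:
#   "tokens_in"            - tokenized prompt (input_ids, attention_mask)
#   "tokens_out"           - model.generate() output (sequences, ...)
#   "logits", "log_probs"  - filled in by log_probs_from_scores
#   "action"               - text decoded by tokenizer.batch_decode
print(list(out.keys()))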
