Commit aa9cf79

[Feature] Set padded token log-prob to 0.0

ghstack-source-id: 30c35d539bd7ac7bf2e5691e5ca2bb402219a2a0
Pull Request resolved: #2856

1 parent 4bb8a4f
1 file changed: torchrl/modules/llm/vllm_policy.py (+18 -6)
@@ -24,7 +24,7 @@
     TensorDictSequential as Seq,
     WrapModule,
 )
-from tensordict.utils import _zip_strict
+from tensordict.utils import _zip_strict, expand_as_right
 
 from torchrl.data import LLMData
 
@@ -130,6 +130,9 @@ def from_vllm(
     token_key: NestedKey = ("tokens",)
     attention_mask_key: NestedKey = ("attention_mask",)
 
+    # retrieve the padding value - we use this to make the log-probs of pad token = 1
+    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
+
     # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks
     if tokenizer is None:
         tokenizer = model.get_tokenizer()
@@ -264,8 +267,6 @@ def to_list(tokens, attention_mask):
         strict=True,
     )
 
-    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
-
     def get_output_tokens_and_log_probs(td, padding_value=padding_value):
         td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"])
         if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict):
@@ -280,10 +281,18 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value):
                     layout=torch.strided
                 ).to_padded_tensor(padding=padding_value)
             tokens_response_td.rename_key_("token_ids", "tokens_response")
-            # td["tokens_response"] = outputs.token_ids
             if return_log_probs:
+                padded_values = tokens_response_td["tokens_response"] == padding_value
                 tokens_response_td.rename_key_("logprobs", "log_probs")
-                # td["log_probs"] = outputs.logprobs.unsqueeze(-1)
+                if padded_values.any():
+                    print(
+                        "padded_values:",
+                        padded_values.sum(),
+                        torch.where(padded_values),
+                    )
+                lps = tokens_response_td["log_probs"]
+                lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
+                tokens_response_td["log_probs"] = lps
             td.update(tokens_response_td)
         elif not generate:
             td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1)
@@ -295,7 +304,10 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value):
 
     def translate_lps(tokens_response, x):
         # we disregard the tokens from the prompt to focus on those of the response
-        return x[..., -tokens_response.shape[-1] :, :]
+        padded = tokens_response == padding_value
+        lps = x[..., -tokens_response.shape[-1] :, :]
+        lps[padded] = 0.0
+        return x
 
     module_dict["translate_lps"] = Mod(
         translate_lps,
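
For reference, the core of the change is the masking step inside get_output_tokens_and_log_probs: wherever a response token equals padding_value, the corresponding log-probability is forced to 0.0 (i.e. probability 1), so padding positions no longer contribute to sequence log-likelihoods. Below is a minimal standalone sketch of that masking; the tensor names, shapes, and pad id are illustrative only and not part of the torchrl API.

import torch
from tensordict.utils import expand_as_right

# Toy example: batch of 2 responses, 4 tokens each, padded with token id 0 (illustrative values).
padding_value = 0
tokens_response = torch.tensor([[5, 7, 9, 0], [3, 0, 0, 0]])
log_probs = torch.randn(2, 4, 1)  # one log-prob per generated token

# Boolean mask of padded positions in the response.
padded = tokens_response == padding_value

# Broadcast the mask over the trailing dim of log_probs and zero out the padded entries,
# mirroring torch.where(expand_as_right(~padded_values, lps), lps, 0.0) in the diff above.
log_probs = torch.where(expand_as_right(~padded, log_probs), log_probs, 0.0)

assert (log_probs[padded] == 0.0).all()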
