     TensorDictSequential as Seq,
     WrapModule,
 )
-from tensordict.utils import _zip_strict
+from tensordict.utils import _zip_strict, expand_as_right

 from torchrl.data import LLMData
@@ -130,6 +130,9 @@ def from_vllm(
     token_key: NestedKey = ("tokens",)
     attention_mask_key: NestedKey = ("attention_mask",)

+    # retrieve the padding value - we use this to set the log-prob of pad tokens to 0 (probability 1)
+    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
+
     # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks
     if tokenizer is None:
         tokenizer = model.get_tokenizer()
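For reference, a minimal sketch of what this pad-token-id lookup returns, assuming a Hugging Face-style tokenizer; the gpt2 checkpoint and the eos-as-pad assignment are illustrative assumptions, not part of this change:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
    # tokenizing the pad-token string yields its id as the first input id
    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
    print(padding_value)  # 50256: gpt2's eos token reused as pad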
@@ -264,8 +267,6 @@ def to_list(tokens, attention_mask):
         strict=True,
     )

-    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
-
     def get_output_tokens_and_log_probs(td, padding_value=padding_value):
         td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"])
         if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict):
@@ -280,10 +281,18 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value):
                     layout=torch.strided
                 ).to_padded_tensor(padding=padding_value)
             tokens_response_td.rename_key_("token_ids", "tokens_response")
-            # td["tokens_response"] = outputs.token_ids
             if return_log_probs:
+                padded_values = tokens_response_td["tokens_response"] == padding_value
                 tokens_response_td.rename_key_("logprobs", "log_probs")
-                # td["log_probs"] = outputs.logprobs.unsqueeze(-1)
+                if padded_values.any():
+                    print(
+                        "padded_values:",
+                        padded_values.sum(),
+                        torch.where(padded_values),
+                    )
+                    lps = tokens_response_td["log_probs"]
+                    lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
+                    tokens_response_td["log_probs"] = lps
             td.update(tokens_response_td)
         elif not generate:
             td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1)
@@ -295,7 +304,10 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value):

     def translate_lps(tokens_response, x):
         # we disregard the tokens from the prompt to focus on those of the response
-        return x[..., -tokens_response.shape[-1] :, :]
+        padded = tokens_response == padding_value
+        lps = x[..., -tokens_response.shape[-1] :, :]
+        lps[padded] = 0.0
+        return lps

     module_dict["translate_lps"] = Mod(
         translate_lps,
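A toy trace of the rewritten translate_lps under assumed shapes (prompt of length 5, response of length 3). Because basic slicing returns a view, the in-place zeroing also mutates x; returning the sliced lps, as above, is what restricts the output to the response tokens:

    import torch

    padding_value = 0
    tokens_response = torch.tensor([[5, 7, 0]])  # response of length 3, one pad
    x = torch.randn(1, 8, 1)                     # prompt + response log-probs

    lps = x[..., -tokens_response.shape[-1] :, :]  # keep the last 3 steps
    lps[tokens_response == padding_value] = 0.0    # zero the pad position in-place
    assert lps.shape == (1, 3, 1)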