@@ -53,11 +53,11 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
5353 - "tokens_out", "scores"
5454
5555 """
56- # TODO: how do we avoid getting these?
5756 tokens_out = td ["tokens_out" , "sequences" ]
5857 seq_len = tokens_out .shape [1 ]
5958
6059 del td ["tokens_out" , "past_key_values" ]
60+
6161 scores = dict (td ["tokens_out" , "scores" ].items ())
6262 scores = torch .stack (
6363 [scores [str (k )] for k in range (len (scores ))], 1
@@ -90,15 +90,18 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
9090 - "forward", "past_key_values"
9191 - "forward"
9292 """
93- # TODO: how do we avoid getting these?
93+ tokens_out = td ["tokens_response" , "input_ids" ]
94+ seq_len = tokens_out .shape [- 1 ]
95+
9496 del td ["forward" , "past_key_values" ]
97+
9598 scores = td ["forward" , "logits" ]
99+ scores = scores [..., - seq_len :, :]
96100 logits = scores - scores .logsumexp (dim = - 1 , keepdim = True )
97101 td ["logits" ] = scores
98102 del td ["forward" ]
99103 scores .shape [1 ]
100- tokens = td ["tokens_in" , "input_ids" ]
101- log_probs = logits .gather (- 1 , tokens .unsqueeze (- 1 ))
104+ log_probs = logits .gather (- 1 , tokens_out .unsqueeze (- 1 ))
102105 td ["log_probs" ] = log_probs
103106 return td
104107
0 commit comments