utils.py
import gc
import json
import time
import torch
from flask import Flask, jsonify, request
from transformers import AutoTokenizer, AutoModelForCausalLM


def to_tokens_and_logprobs(model, tokenizer, input_texts, device, special_ids):
    """
    Compute per-token log-probabilities for the given input text(s).

    :param model: causal language model (e.g. loaded with AutoModelForCausalLM)
    :param tokenizer: matching tokenizer (e.g. loaded with AutoTokenizer)
    :param input_texts: text (or list of texts) to score
    :param device: torch device the model lives on
    :param special_ids: extra token ids to skip, in addition to tokenizer.all_special_ids
    :return: ([tokens], logprobs, top_logprobs_dicts) for the first sequence,
             built from (token, logprob) pairs such as
             [('One', -5.882715702056885),
              (' plus', -9.785109519958496),
              (' one', -0.7229145169258118),
              (' is', -2.494063377380371),
              (' two', -6.137458324432373)]
    """
    input_ids = tokenizer(input_texts,
                          max_length=1024,
                          truncation=True,
                          return_tensors="pt").input_ids.to(device)
    # Run the model without building a gradient graph; only the logits are needed.
    with torch.no_grad():
        outputs = model(input_ids)
    probs = torch.log_softmax(outputs.logits, dim=-1).detach()
    # Collect the log-probability of each generated token: the distribution at
    # position i predicts the token at position i + 1, so drop the last
    # distribution and the first token to align them.
    probs = probs[:, :-1, :]
    input_ids = input_ids[:, 1:]
    gen_probs = torch.gather(probs, 2, input_ids[:, :, None]).squeeze(-1)
    batch = []
    tokens = []
    for input_sentence, input_probs in zip(input_ids, gen_probs):
        text_sequence = []
        for token, p in zip(input_sentence, input_probs):
            # Skip the tokenizer's special tokens plus any caller-supplied ids.
            if token not in tokenizer.all_special_ids + special_ids:
                text_sequence.append((tokenizer.decode(token.item()), p.item()))
                tokens.append(tokenizer.decode(token.item()))
        batch.append(text_sequence)
    # Only the first sequence is kept: the function assumes a batch size of 1.
    batch = batch[0]
    logprobs = [x[1] for x in batch]
    top_logprobs_dicts = [[{x[0].strip(): x[1]} for x in batch]]
    del input_ids, outputs, probs
    torch.cuda.empty_cache()
    print("to_tokens_and_logprobs", logprobs, top_logprobs_dicts)
    return [tokens], logprobs, top_logprobs_dicts


def convert_tokens(text, tokenizer):
    """
    Split a text into the token strings produced by the tokenizer.

    :param text: input string to tokenize
    :param tokenizer: tokenizer used to encode the text and decode each id
    :return: list of decoded token strings, one per input id
    """
    input_ids = tokenizer(text)['input_ids']
    tokens = []
    for token_id in input_ids:
        tokens.append(tokenizer.decode(token_id))
    return tokens
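

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a small
# causal LM such as "gpt2" is available, and that to_tokens_and_logprobs is
# called with a single input text, matching its batch-size-1 assumption. The
# model name, device choice, and empty special_ids list are illustrative
# assumptions, not values taken from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical example model; any causal LM / tokenizer pair should work.
    model_name = "gpt2"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    model.eval()

    text = "One plus one is two"

    # Token strings as seen by the tokenizer.
    print(convert_tokens(text, tokenizer))

    # Per-token log-probabilities; special_ids=[] means only the tokenizer's
    # own special tokens are skipped.
    tokens, logprobs, top_logprobs_dicts = to_tokens_and_logprobs(
        model, tokenizer, text, device, special_ids=[]
    )
    print(tokens)
    print(logprobs)
    print(top_logprobs_dicts)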