import torch
from torch import nn
from typing import Tuple, Dict
from nltk.tokenize import PunktSentenceTokenizer, TreebankWordTokenizer

from datasets import get_clean_text, get_label_map, load_data
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# path to the checkpoint
checkpoint_path = '/Users/zou/Renovamen/Developing/Text-Classification/checkpoints/checkpoint_fasttext_agnews.pth.tar'

# pad limits
# only make sense when model_name = 'han'
sentence_limit_per_doc = 15
word_limit_per_sentence = 20
# only makes sense when model_name != 'han'
word_limit = 200
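
# For example, with the limits above a HAN input is always a fixed 15 x 20 grid:
# a 3-sentence document keeps all 3 sentences (padded with 12 all-zero rows),
# each sentence keeps at most its first 20 words, and shorter sentences are
# right-padded with index 0.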

def prepro_doc(
    document: str, word_map: Dict[str, int]
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor]:
    """
    Preprocess a document into a hierarchical representation

    Parameters
    ----------
    document : str
        A document in text form

    word_map : Dict[str, int]
        Word2ix map

    Returns
    -------
    encoded_doc : torch.LongTensor
        Pre-processed tokenized document

    sentences_per_doc : torch.LongTensor
        Document length (number of sentences)

    words_per_each_sentence : torch.LongTensor
        Sentence lengths
    """
    # tokenizers
    sent_tokenizer = PunktSentenceTokenizer()
    word_tokenizer = TreebankWordTokenizer()

    # a list to store the document tokenized into words
    doc = list()

    # tokenize document into sentences
    sentences = list()
    for paragraph in get_clean_text(document).splitlines():
        sentences.extend(sent_tokenizer.tokenize(paragraph))

    # tokenize sentences into words, dropping empty sentences
    for s in sentences[:sentence_limit_per_doc]:
        w = word_tokenizer.tokenize(s)[:word_limit_per_sentence]
        if len(w) == 0:
            continue
        doc.append(w)

    # number of sentences in the document
    sentences_per_doc = torch.LongTensor([len(doc)]).to(device)  # (1)

    # number of words in each sentence
    words_per_each_sentence = torch.LongTensor([len(s) for s in doc]).unsqueeze(0).to(device)  # (1, n_sentences)

    # encode document with indices from the word map, padding each sentence to
    # word_limit_per_sentence and the document to sentence_limit_per_doc
    encoded_doc = list(
        map(lambda s: list(
            map(lambda w: word_map.get(w, word_map['<unk>']), s)
        ) + [0] * (word_limit_per_sentence - len(s)), doc)
    ) + [[0] * word_limit_per_sentence] * (sentence_limit_per_doc - len(doc))
    encoded_doc = torch.LongTensor(encoded_doc).unsqueeze(0).to(device)  # (1, sentence_limit_per_doc, word_limit_per_sentence)

    return encoded_doc, sentences_per_doc, words_per_each_sentence
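
# A minimal sketch of the shapes prepro_doc produces (the counts are
# illustrative: they assume a word_map containing '<unk>' and that
# get_clean_text leaves the tokens of this toy input intact):
#
#   enc, n_sents, n_words = prepro_doc("How do computers work? I have a CPU.", word_map)
#   enc.shape   # torch.Size([1, 15, 20]) -- (1, sentence_limit_per_doc, word_limit_per_sentence)
#   n_sents     # tensor([2])
#   n_words     # tensor([[5, 5]]) -- Treebank tokens per sentence, punctuation included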

def prepro_sent(
    text: str, word_map: Dict[str, int]
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """
    Preprocess a sentence

    Parameters
    ----------
    text : str
        A sentence in text form

    word_map : Dict[str, int]
        Word2ix map

    Returns
    -------
    encoded_sent : torch.LongTensor
        Pre-processed tokenized sentence

    words_per_sentence : torch.LongTensor
        Sentence length
    """
    # tokenizer
    word_tokenizer = TreebankWordTokenizer()

    # tokenize sentence into words
    sentence = word_tokenizer.tokenize(text)[:word_limit]

    # number of words in the sentence
    words_per_sentence = torch.LongTensor([len(sentence)]).to(device)  # (1)

    # encode sentence with indices from the word map, padding to word_limit
    encoded_sent = list(
        map(lambda w: word_map.get(w, word_map['<unk>']), sentence)
    ) + [0] * (word_limit - len(sentence))
    encoded_sent = torch.LongTensor(encoded_sent).unsqueeze(0).to(device)  # (1, word_limit)

    return encoded_sent, words_per_sentence
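
# The flat (non-HAN) counterpart of the sketch above (values illustrative):
#
#   enc, n_words = prepro_sent("I have a CPU.", word_map)
#   enc.shape   # torch.Size([1, 200]) -- (1, word_limit)
#   n_words     # tensor([5])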

def classify(
    text: str, model: nn.Module, model_name: str, dataset_name: str, word_map: Dict[str, int]
) -> str:
    """
    Classify a text using the given model.

    Parameters
    ----------
    text : str
        A document or sentence in text form

    model : nn.Module
        A loaded model

    model_name : str
        Name of the model

    dataset_name : str
        Name of the dataset

    word_map : Dict[str, int]
        Word2ix map

    Returns
    -------
    prediction : str
        The predicted category with its probability
    """
    _, rev_label_map = get_label_map(dataset_name)

    if model_name in ['han']:
        # preprocess document
        encoded_doc, sentences_per_doc, words_per_each_sentence = prepro_doc(text, word_map)
        # run through model
        scores, word_alphas, sentence_alphas = model(
            encoded_doc,
            sentences_per_doc,
            words_per_each_sentence
        )  # (1, n_classes), (1, n_sentences, max_sent_len_in_document), (1, n_sentences)
    else:
        # preprocess sentence
        encoded_sent, words_per_sentence = prepro_sent(text, word_map)
        # run through model
        scores = model(encoded_sent, words_per_sentence)  # (1, n_classes)

    scores = scores.squeeze(0)  # (n_classes)
    scores = nn.functional.softmax(scores, dim=0)  # (n_classes)

    # find the best prediction and its probability
    score, prediction = scores.max(dim=0)
    prediction = 'Category: {category}, Probability: {score:.2f}%'.format(
        category=rev_label_map[prediction.item()],
        score=score.item() * 100
    )
    return prediction

    # word_alphas = word_alphas.squeeze(0)  # (n_sentences, max_sent_len_in_document)
    # sentence_alphas = sentence_alphas.squeeze(0)  # (n_sentences)
    # words_per_each_sentence = words_per_each_sentence.squeeze(0)  # (n_sentences)
    # return doc, scores, word_alphas, sentence_alphas, words_per_each_sentence
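
# Note: the commented-out return above references `doc` (the tokenized
# document), which is local to prepro_doc, not to classify; wiring the
# visualize_attention call in __main__ back up would presumably also require
# prepro_doc to return its token lists alongside the tensors.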

if __name__ == '__main__':
    text = 'How do computers work? I have a CPU I want to use. But my keyboard and motherboard do not help.\n\n You can just google how computers work. Honestly, its easy.'
    # text = 'But think about it! It\'s so cool. Physics is really all about math. what feynman said, hehe'
    # text = "I think I'm falling sick. There was some indigestion at first. But now a fever is beginning to take hold."
    # text = "I want to tell you something important. Get into the stock market and investment funds. Make some money so you can buy yourself some yogurt."
    # text = "You know what's wrong with this country? republicans and democrats. always at each other's throats\n There's no respect, no bipartisanship."

    # load model and word map
    model, model_name, _, dataset_name, word_map, _ = load_checkpoint(checkpoint_path, device)
    model = model.to(device)
    model.eval()

    # visualize_attention(*classify(text, model, model_name, dataset_name, word_map))
    prediction = classify(text, model, model_name, dataset_name, word_map)
    print(prediction)
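
# Running the script prints one line in the format produced by classify(); the
# actual label and score depend entirely on the loaded checkpoint:
#
#   Category: <category>, Probability: <score>%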