-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path generate.py
53 lines (45 loc) · 1.65 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import pickle
from mrrnn import MRRNN
from mrrnn import Configuration
# Token ids that mark the end of an utterance, one per vocabulary.
LANGUAGE_END = 18575  # word-level vocabulary (used as config.end_of_word_utt)
COARSE_END = 10       # coarse vocabulary (used as config.end_of_coarse_utt)
def load_pickle(path):
    """Unpickle and return the object stored in the file at *path*.

    The file is opened in binary mode: pickle data is bytes, and text mode
    ("r") raises TypeError on Python 3 and can corrupt data on Windows.
    NOTE(review): pickle can execute arbitrary code -- only load files from
    trusted sources.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def load_vocab(path):
    """Load a pickled vocabulary and return it sorted by each entry's index 1.

    Presumably each entry is a tuple like (token, token_id, ...) whose ids
    run 0..N-1, so the sorted list can be indexed directly by token id --
    TODO confirm against the dictionary files.
    """
    return sorted(load_pickle(path), key=lambda entry: entry[1])


def decode_tokens(token_ids, vocab):
    """Return the strings at ``vocab[i][0]`` for each id ``i``, space-separated."""
    return " ".join(vocab[token_id][0] for token_id in token_ids)


if __name__ == "__main__":
    # Vocabularies for the word-level and coarse (abstract) token streams.
    vocab_word = load_vocab("./data/Dataset.dict.pkl")
    vocab_coarse = load_vocab("./data/abstract.dict.pkl")

    # Test dialogues in both representations.
    test_word_data = load_pickle("./data/Test.dialogues.pkl")
    test_coarse_data = load_pickle("./data/abstract.test.dialogues.pkl")

    config = Configuration()
    config.word_vocab_size = len(vocab_word)
    config.coarse_vocab_size = len(vocab_coarse)
    config.end_of_word_utt = LANGUAGE_END
    config.end_of_coarse_utt = COARSE_END

    # Create the model and restore trained weights.
    model = MRRNN(config)
    N_dialogue = 10
    file_name = "./ckpts/training_5/trained.ckpt"
    model.restore(file_name)

    dial_word, dial_coarse = model.split_utterances(
        [test_word_data[N_dialogue]], [test_coarse_data[N_dialogue]])
    # Condition on every utterance except the last one and generate a reply.
    # NOTE(review): the meaning of the two 20s is defined by MRRNN.generate
    # (likely length limits) -- confirm there.
    prediction = model.generate(dial_word[0][:-1], dial_coarse[0][:-1], 20, 20)

    # Print the context utterances, each without its final token (presumably
    # the end-of-utterance marker), then the generated reply.
    for utterance in dial_word[0][:-1]:
        print(decode_tokens(utterance[:-1], vocab_word))
    print(decode_tokens(prediction, vocab_word))