# coding: UTF-8
"""
file name: model.py
create time: 2017-06-25 (Sunday) 10:47:48
author: Jipeng Huang
e-mail: [email protected]
github: https://github.com/hjptriplebee
"""
import tensorflow as tf
import numpy as np
from config import *
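
# The star import above is expected to provide the module-level settings used
# below: batchSize, learningRateBase, learningRateDecreaseStep, epochNum,
# checkpointsPath, and generateNum.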


def buildModel(wordNum, gtX, hidden_units=128, layers=2):
    """Build the RNN language model: embedding -> stacked LSTM -> softmax."""
    with tf.variable_scope("embedding"):  # embedding
        embedding = tf.get_variable("embedding", [wordNum, hidden_units], dtype=tf.float32)
        inputbatch = tf.nn.embedding_lookup(embedding, gtX)
    # build one cell object per layer: reusing a single cell ([basicCell] * layers)
    # would share weights across layers and fails on newer TF 1.x releases
    stackCell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.BasicLSTMCell(hidden_units, state_is_tuple=True) for _ in range(layers)])
    initState = stackCell.zero_state(np.shape(gtX)[0], tf.float32)
    outputs, finalState = tf.nn.dynamic_rnn(stackCell, inputbatch, initial_state=initState)
    outputs = tf.reshape(outputs, [-1, hidden_units])  # flatten to [batch * time, hidden_units]
    with tf.variable_scope("softmax"):
        w = tf.get_variable("w", [hidden_units, wordNum])
        b = tf.get_variable("b", [wordNum])
        logits = tf.matmul(outputs, w) + b
        probs = tf.nn.softmax(logits)
    return logits, probs, stackCell, initState, finalState
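
# Shape sketch for buildModel (illustrative only: "demoX" and the numbers 4 and
# 5000 are hypothetical, not project settings; run in a fresh graph, since
# buildModel creates variables under fixed scope names):
#
#   demoX = tf.placeholder(tf.int32, shape=[4, None])
#   logits, probs, cell, init, final = buildModel(wordNum=5000, gtX=demoX)
#   # probs has shape [4 * timeSteps, 5000]: one softmax distribution per
#   # position of the flattened batch, matching the flattened targets in train()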


def train(X, Y, wordNum, reload=True):
    """Train the language model on batched inputs X and shifted targets Y."""
    gtX = tf.placeholder(tf.int32, shape=[batchSize, None])  # input
    gtY = tf.placeholder(tf.int32, shape=[batchSize, None])  # output
    logits, probs, _, _, _ = buildModel(wordNum, gtX)
    targets = tf.reshape(gtY, [-1])
    # loss: per-position cross-entropy with uniform weights
    # (legacy_seq2seq.sequence_loss_by_example takes (logits, targets, weights);
    # the extra wordNum argument in the original bound to average_across_timesteps,
    # which defaults to True anyway, so it is dropped here)
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
        [logits], [targets], [tf.ones_like(targets, dtype=tf.float32)])
    cost = tf.reduce_mean(loss)
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5)
    # feed the learning rate through a placeholder: a plain Python float would be
    # frozen into the optimizer at graph-build time, so later decay would be ignored
    learningRate = tf.placeholder(tf.float32, shape=[])
    optimizer = tf.train.AdamOptimizer(learningRate)
    trainOP = optimizer.apply_gradients(zip(grads, tvars))
    globalStep = 0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        if reload:
            checkPoint = tf.train.get_checkpoint_state(checkpointsPath)
            # if a checkpoint exists, restore it
            if checkPoint and checkPoint.model_checkpoint_path:
                saver.restore(sess, checkPoint.model_checkpoint_path)
                print("restored %s" % checkPoint.model_checkpoint_path)
            else:
                print("no checkpoint found!")
        currentLearningRate = learningRateBase
        for epoch in range(epochNum):
            if globalStep % learningRateDecreaseStep == 0:  # decay the learning rate as training progresses
                currentLearningRate = learningRateBase * (0.95 ** epoch)
            epochSteps = len(X)  # number of batches per epoch
            for step, (x, y) in enumerate(zip(X, Y)):
                globalStep = epoch * epochSteps + step
                _, batchLoss = sess.run([trainOP, cost],
                                        feed_dict={gtX: x, gtY: y, learningRate: currentLearningRate})
                print("epoch: %d steps: %d/%d loss: %.3f" % (epoch, step, epochSteps, batchLoss))
                if globalStep % 1000 == 0:
                    print("save model")
                    saver.save(sess, checkpointsPath + "/poem", global_step=epoch)
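
# Typical training invocation (a sketch: "loadData" and its return values are
# hypothetical stand-ins for the project's data-preparation step, not part of
# this file):
#
#   trainX, trainY, wordNum, wordToID, words = loadData()
#   train(trainX, trainY, wordNum, reload=True)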


def probsToWord(weights, words):
    """Sample a single word from the model's output distribution."""
    t = np.cumsum(weights)  # prefix sums: each word owns an interval proportional to its probability
    s = np.sum(weights)
    coef = np.random.rand(1)
    index = int(np.searchsorted(t, coef * s))  # wider intervals (higher probability) are hit more often
    return words[index]
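
# Quick sanity check for probsToWord (toy numbers, not project data):
#
#   demoProbs = np.array([[0.1, 0.7, 0.2]])    # cumsum -> [0.1, 0.8, 1.0]
#   demoWords = ["a", "b", "c"]
#   probsToWord(demoProbs, demoWords)          # "b" about 70% of the time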


def test(wordNum, wordToID, words):
    """Generate generateNum poems by sampling from the trained model."""
    gtX = tf.placeholder(tf.int32, shape=[1, None])  # input
    logits, probs, stackCell, initState, finalState = buildModel(wordNum, gtX)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        checkPoint = tf.train.get_checkpoint_state(checkpointsPath)
        # a checkpoint is required here: without trained weights there is nothing to sample
        if checkPoint and checkPoint.model_checkpoint_path:
            saver.restore(sess, checkPoint.model_checkpoint_path)
            print("restored %s" % checkPoint.model_checkpoint_path)
        else:
            print("no checkpoint found!")
            exit(1)
        poems = []
        for i in range(generateNum):
            state = sess.run(stackCell.zero_state(1, tf.float32))
            x = np.array([[wordToID['[']]])  # '[' is the start-of-poem sign
            probs1, state = sess.run([probs, finalState], feed_dict={gtX: x, initState: state})
            word = probsToWord(probs1, words)
            poem = ''
            while word not in (']', ' '):  # ']' is the end-of-poem sign
                poem += word
                if word == '。':
                    poem += '\n'
                x = np.array([[wordToID[word]]])
                # feed the sampled word back in, carrying the LSTM state forward
                probs2, state = sess.run([probs, finalState], feed_dict={gtX: x, initState: state})
                word = probsToWord(probs2, words)
            print(poem)
            poems.append(poem)
        return poems


def testHead(wordNum, wordToID, words, characters):
    """Generate an acrostic poem whose lines start with the given characters."""
    gtX = tf.placeholder(tf.int32, shape=[1, None])  # input
    logits, probs, stackCell, initState, finalState = buildModel(wordNum, gtX)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        checkPoint = tf.train.get_checkpoint_state(checkpointsPath)
        # a checkpoint is required here as well
        if checkPoint and checkPoint.model_checkpoint_path:
            saver.restore(sess, checkPoint.model_checkpoint_path)
            print("restored %s" % checkPoint.model_checkpoint_path)
        else:
            print("no checkpoint found!")
            exit(1)
        # flag alternates between -1 and 1 so lines end with "," and "。" in turn
        flag = 1
        endSign = {-1: ",", 1: "。"}
        poem = ''
        state = sess.run(stackCell.zero_state(1, tf.float32))
        x = np.array([[wordToID['[']]])
        probs1, state = sess.run([probs, finalState], feed_dict={gtX: x, initState: state})
        for c in characters:
            word = c  # force each line to start with the given character
            flag = -flag
            while word not in (']', ',', '。', ' '):
                poem += word
                x = np.array([[wordToID[word]]])
                probs2, state = sess.run([probs, finalState], feed_dict={gtX: x, initState: state})
                word = probsToWord(probs2, words)
            poem += endSign[flag]
            # to keep the context, the state must also be updated with the punctuation
            if endSign[flag] == '。':
                _, state = sess.run([probs, finalState],
                                    feed_dict={gtX: np.array([[wordToID["。"]]]), initState: state})
                poem += '\n'
            else:
                _, state = sess.run([probs, finalState],
                                    feed_dict={gtX: np.array([[wordToID[","]]]), initState: state})
        print(characters)
        print(poem)
        return poem
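
# Example generation calls (a sketch: wordNum, wordToID and words would come
# from the project's preprocessing, and "春夏秋冬" is just an illustrative
# acrostic seed):
#
#   poems = test(wordNum, wordToID, words)                    # free-form poems
#   acrostic = testHead(wordNum, wordToID, words, "春夏秋冬")   # one acrostic poem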