Skip to content

Commit 1096f96

Browse files
committed
ChatBot 예제 최신 텐서플로에 맞게 수정
1 parent 57dfa7d commit 1096f96

12 files changed

+29
-29
lines changed

10 - RNN/ChatBot/chat.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ def run(self):
2626
line = sys.stdin.readline()
2727

2828
while line:
29-
print(self.get_replay(line.strip()))
29+
print(self._get_replay(line.strip()))
3030

3131
sys.stdout.write("\n> ")
3232
sys.stdout.flush()
3333

3434
line = sys.stdin.readline()
3535

36-
def decode(self, enc_input, dec_input):
36+
def _decode(self, enc_input, dec_input):
3737
if type(dec_input) is np.ndarray:
3838
dec_input = dec_input.tolist()
3939

@@ -46,7 +46,7 @@ def decode(self, enc_input, dec_input):
4646

4747
return self.model.predict(self.sess, [enc_input], [dec_input])
4848

49-
def get_replay(self, msg):
49+
def _get_replay(self, msg):
5050
enc_input = self.dialog.tokenizer(msg)
5151
enc_input = self.dialog.tokens_to_ids(enc_input)
5252
dec_input = []
@@ -57,7 +57,7 @@ def get_replay(self, msg):
5757
# 다만 상황에 따라서는 이런 방식이 더 유연할 수도 있을 듯
5858
curr_seq = 0
5959
for i in range(FLAGS.max_decode_len):
60-
outputs = self.decode(enc_input, dec_input)
60+
outputs = self._decode(enc_input, dec_input)
6161
if self.dialog.is_eos(outputs[0][curr_seq]):
6262
break
6363
elif self.dialog.is_defined(outputs[0][curr_seq]) is not True:
@@ -75,5 +75,6 @@ def main(_):
7575
chatbot = ChatBot(FLAGS.voc_path, FLAGS.train_dir)
7676
chatbot.run()
7777

78+
7879
if __name__ == "__main__":
7980
tf.app.run()

10 - RNN/ChatBot/dialog.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import tensorflow as tf
33
import numpy as np
44
import re
5-
import codecs
65

76
from config import FLAGS
87

@@ -32,11 +31,11 @@ def decode(self, indices, string=False):
3231
tokens = [[self.vocab_list[i] for i in dec] for dec in indices]
3332

3433
if string:
35-
return self.decode_to_string(tokens[0])
34+
return self._decode_to_string(tokens[0])
3635
else:
3736
return tokens
3837

39-
def decode_to_string(self, tokens):
38+
def _decode_to_string(self, tokens):
4039
text = ' '.join(tokens)
4140
return text.strip()
4241

@@ -50,7 +49,7 @@ def is_eos(self, voc_id):
5049
def is_defined(self, voc_id):
5150
return voc_id in self._PRE_DEFINED_
5251

53-
def max_len(self, batch_set):
52+
def _max_len(self, batch_set):
5453
max_len_input = 0
5554
max_len_output = 0
5655

@@ -64,7 +63,7 @@ def max_len(self, batch_set):
6463

6564
return max_len_input, max_len_output + 1
6665

67-
def pad(self, seq, max_len, start=None, eos=None):
66+
def _pad(self, seq, max_len, start=None, eos=None):
6867
if start:
6968
padded_seq = [self._STA_ID_] + seq
7069
elif eos:
@@ -77,16 +76,16 @@ def pad(self, seq, max_len, start=None, eos=None):
7776
else:
7877
return padded_seq
7978

80-
def pad_left(self, seq, max_len):
79+
def _pad_left(self, seq, max_len):
8180
if len(seq) < max_len:
8281
return ([self._PAD_ID_] * (max_len - len(seq))) + seq
8382
else:
8483
return seq
8584

8685
def transform(self, input, output, input_max, output_max):
87-
enc_input = self.pad(input, input_max)
88-
dec_input = self.pad(output, output_max, start=True)
89-
target = self.pad(output, output_max, eos=True)
86+
enc_input = self._pad(input, input_max)
87+
dec_input = self._pad(output, output_max, start=True)
88+
target = self._pad(output, output_max, eos=True)
9089

9190
# 구글 방식으로 입력을 인코더에 역순으로 입력한다.
9291
enc_input.reverse()
@@ -117,7 +116,7 @@ def next_batch(self, batch_size):
117116

118117
# TODO: 구글처럼 버킷을 이용한 방식으로 변경
119118
# 간단하게 만들기 위해 구글처럼 버킷을 쓰지 않고 같은 배치는 같은 사이즈를 사용하도록 만듬
120-
max_len_input, max_len_output = self.max_len(batch_set)
119+
max_len_input, max_len_output = self._max_len(batch_set)
121120

122121
for i in range(0, len(batch_set) - 1, 2):
123122
enc, dec, tar = self.transform(batch_set[i], batch_set[i+1],
Binary file not shown.
Binary file not shown.

10 - RNN/ChatBot/model.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ def __init__(self, vocab_size, n_hidden=128, n_layers=3):
2525
self.bias = tf.Variable(tf.zeros([self.vocab_size]), name="bias")
2626
self.global_step = tf.Variable(0, trainable=False, name="global_step")
2727

28-
self.build_model()
28+
self._build_model()
2929

3030
self.saver = tf.train.Saver(tf.global_variables())
3131

32-
def build_model(self):
33-
self.enc_input = tf.transpose(self.enc_input, [1, 0, 2])
34-
self.dec_input = tf.transpose(self.dec_input, [1, 0, 2])
32+
def _build_model(self):
33+
# self.enc_input = tf.transpose(self.enc_input, [1, 0, 2])
34+
# self.dec_input = tf.transpose(self.dec_input, [1, 0, 2])
3535

36-
enc_cell, dec_cell = self.build_cells()
36+
enc_cell, dec_cell = self._build_cells()
3737

3838
with tf.variable_scope('encode'):
3939
outputs, enc_states = tf.nn.dynamic_rnn(enc_cell, self.enc_input, dtype=tf.float32)
@@ -42,24 +42,24 @@ def build_model(self):
4242
outputs, dec_states = tf.nn.dynamic_rnn(dec_cell, self.dec_input, dtype=tf.float32,
4343
initial_state=enc_states)
4444

45-
self.logits, self.cost, self.train_op = self.build_ops(outputs, self.targets)
45+
self.logits, self.cost, self.train_op = self._build_ops(outputs, self.targets)
4646

4747
self.outputs = tf.argmax(self.logits, 2)
4848

49-
def cell(self, n_hidden, output_keep_prob):
50-
rnn_cell = tf.contrib.rnn.BasicRNNCell(self.n_hidden)
51-
rnn_cell = tf.contrib.rnn.DropoutWrapper(rnn_cell, output_keep_prob=output_keep_prob)
49+
def _cell(self, output_keep_prob):
50+
rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(self.n_hidden)
51+
rnn_cell = tf.nn.rnn_cell.DropoutWrapper(rnn_cell, output_keep_prob=output_keep_prob)
5252
return rnn_cell
5353

54-
def build_cells(self, output_keep_prob=0.5):
55-
enc_cell = tf.contrib.rnn.MultiRNNCell([self.cell(self.n_hidden, output_keep_prob)
54+
def _build_cells(self, output_keep_prob=0.5):
55+
enc_cell = tf.nn.rnn_cell.MultiRNNCell([self._cell(output_keep_prob)
5656
for _ in range(self.n_layers)])
57-
dec_cell = tf.contrib.rnn.MultiRNNCell([self.cell(self.n_hidden, output_keep_prob)
57+
dec_cell = tf.nn.rnn_cell.MultiRNNCell([self._cell(output_keep_prob)
5858
for _ in range(self.n_layers)])
5959

6060
return enc_cell, dec_cell
6161

62-
def build_ops(self, outputs, targets):
62+
def _build_ops(self, outputs, targets):
6363
time_steps = tf.shape(outputs)[1]
6464
outputs = tf.reshape(outputs, [-1, self.n_hidden])
6565

10 - RNN/ChatBot/model/checkpoint

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
model_checkpoint_path: "conversation.ckpt-5000"
2-
all_model_checkpoint_paths: "conversation.ckpt-5000"
1+
model_checkpoint_path: "conversation.ckpt-10000"
2+
all_model_checkpoint_paths: "conversation.ckpt-10000"
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)