From 2b5b3359c6bdf82691defa98d6ba423456c27318 Mon Sep 17 00:00:00 2001 From: vanechu Date: Wed, 11 May 2016 23:28:26 +0900 Subject: [PATCH] Implement temperature --- model.py | 3 ++- sample.py | 4 +++- train.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/model.py b/model.py index 5e7e4979..008dbc50 100644 --- a/model.py +++ b/model.py @@ -58,7 +58,7 @@ def loop(prev, _): optimizer = tf.train.AdamOptimizer(self.lr) self.train_op = optimizer.apply_gradients(zip(grads, tvars)) - def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1): + def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1, temperature=1.): state = self.cell.zero_state(1, tf.float32).eval() for char in prime[:-1]: x = np.zeros((1, 1)) @@ -77,6 +77,7 @@ def weighted_pick(weights): x = np.zeros((1, 1)) x[0, 0] = vocab[char] feed = {self.input_data: x, self.initial_state:state} + self.probs = tf.nn.softmax(tf.div(self.logits, temperature)) [probs, state] = sess.run([self.probs, self.final_state], feed) p = probs[0] diff --git a/sample.py b/sample.py index 7c0e0ba9..4748c9d9 100644 --- a/sample.py +++ b/sample.py @@ -20,6 +20,8 @@ def main(): help='prime text') parser.add_argument('--sample', type=int, default=1, help='0 to use max at each timestep, 1 to sample at each timestep, 2 to sample on spaces') + parser.add_argument('--temperature', type=float, default=1., + help='temperature for sampling, within the range of (0,1]') args = parser.parse_args() sample(args) @@ -36,7 +38,7 @@ def sample(args): ckpt = tf.train.get_checkpoint_state(args.save_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) - print(model.sample(sess, chars, vocab, args.n, args.prime, args.sample)) + print(model.sample(sess, chars, vocab, args.n, args.prime, args.sample, args.temperature)) if __name__ == '__main__': main() diff --git a/train.py b/train.py index ac49fd9b..d446081c 100644 --- a/train.py +++ b/train.py @@ -35,7 +35,7 @@ def main(): parser.add_argument('--learning_rate', type=float, default=0.002, help='learning rate') parser.add_argument('--decay_rate', type=float, default=0.97, - help='decay rate for rmsprop') + help='decay rate for rmsprop') parser.add_argument('--init_from', type=str, default=None, help="""continue training from saved model at this path. Path must contain files saved by previous training process: 'config.pkl' : configuration;