diff --git a/hyperparams.py b/hyperparams.py
index 99ba046..51e06fc 100644
--- a/hyperparams.py
+++ b/hyperparams.py
@@ -27,7 +27,7 @@ class Hyperparams:
     num_epochs = 20
     num_heads = 8
     dropout_rate = 0.1
-    sinusoid = False # If True, use sinusoid. If false, positional embedding.
+    sinusoid = True # If True, use sinusoid. If false, positional embedding.
diff --git a/modules.py b/modules.py
index 4222d0a..127bb44 100644
--- a/modules.py
+++ b/modules.py
@@ -4,16 +4,19 @@
 June 2017 by kyubyong park.
 kbpark.linguist@gmail.com.
 https://www.github.com/kyubyong/transformer
+
 '''
 from __future__ import print_function
 import tensorflow as tf
-
+import numpy as np

 def normalize(inputs,
              epsilon = 1e-8,
              scope="ln",
              reuse=None):
-    '''Applies layer normalization.
+    '''
+
+    Applies layer normalization.

     Args:
       inputs: A tensor with 2 or more dimensions, where the first dimension has
@@ -45,7 +48,8 @@ def embedding(inputs,
              scale=True,
              scope="embedding",
              reuse=None):
-    '''Embeds a given tensor.
+    '''
+    Embeds a given tensor.

     Args:
       inputs: A `Tensor` with type `int32` or `int64` containing the ids
@@ -153,15 +157,20 @@ def positional_encoding(inputs,

         # Convert to a tensor
         lookup_table = tf.convert_to_tensor(position_enc)
-
+        lookup_table = tf.cast(lookup_table, tf.float32)
+
         if zero_pad:
             lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
                                       lookup_table[1:, :]), 0)
+
+
         outputs = tf.nn.embedding_lookup(lookup_table, position_ind)
-
+
+
         if scale:
             outputs = outputs * num_units**0.5
-
+
+
         return outputs
@@ -260,7 +269,8 @@ def feedforward(inputs,
                num_units=[2048, 512],
                scope="multihead_attention",
                reuse=None):
-    '''Point-wise feed forward net.
+    '''
+    Point-wise feed forward net.

     Args:
       inputs: A 3d tensor with shape of [N, T, C].
diff --git a/train.py b/train.py
index 45a5732..7e6cecc 100644
--- a/train.py
+++ b/train.py
@@ -88,7 +88,6 @@ def __init__(self, is_training=True):
                 ## Positional Encoding
                 if hp.sinusoid:
                     self.dec += positional_encoding(self.decoder_inputs,
-                                                  vocab_size=hp.maxlen,
                                                  num_units=hp.hidden_units,
                                                  zero_pad=False,
                                                  scale=False,
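
Note on the float32 cast added in modules.py: positional_encoding builds its lookup table with NumPy, whose arithmetic and trig ops return float64 by default, so without the cast TensorFlow raises a dtype mismatch when the table is added to the model's float32 embeddings in train.py. A minimal NumPy sketch of the sinusoid table being built (sinusoid_table is an illustrative name, not a function in the repo):

    import numpy as np

    def sinusoid_table(max_len, num_units):
        # PE(pos, 2i)   = sin(pos / 10000^(2i/num_units))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i/num_units))
        position_enc = np.array([
            [pos / np.power(10000, 2.0 * (i // 2) / num_units) for i in range(num_units)]
            for pos in range(max_len)])
        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # sine on even dims
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # cosine on odd dims
        # NumPy returns float64 here; cast so the table matches the
        # float32 embeddings it is added to (the fix in this diff).
        return position_enc.astype(np.float32)

    # e.g. a table for sequences up to 10 tokens with 512-dim embeddings
    table = sinusoid_table(10, 512)
    assert table.dtype == np.float32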