
Commit 90d536e
Author: Lingjun Liu
Commit message: add examples
1 parent a48e1d3 commit 90d536e

9 files changed: +273, −20 lines

docs/modules/models.rst

Lines changed: 5 additions & 0 deletions
@@ -58,3 +58,8 @@ Seq2seq Luong Attention
 ------------------------

 .. autoclass:: Seq2seqLuongAttention
+
+Transformer
+------------------------
+
+.. autoclass:: Transformer
Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow_datasets as tfds
import tensorflow as tf
import time
import numpy as np
import matplotlib.pyplot as plt
from tensorlayer.models.transformer import Transformer
from tensorlayer.models.transformer.utils import metrics
from tensorlayer.models.transformer.utils import attention_visualisation
import tensorlayer as tl


""" Translation from Portuguese to English with the Transformer model.
This tutorial provides basic instructions on how to define and train a Transformer model in TensorLayer
for a translation task. It also shows how to visualise the attention blocks.
"""


def set_up_dataset():
    # Set up the dataset for Portuguese-English translation from the TED Talks Open Translation Project.
    # This dataset contains approximately 50000 training examples, 1100 validation examples, and 2000 test examples.
    # https://www.ted.com/participate/translate

    examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True,
                                   as_supervised=True)
    train_examples, val_examples = examples['train'], examples['validation']

    # Set up the tokenizer and save it to disk.
    # Build a single subword vocabulary shared by the Portuguese and English sides,
    # since the same tokenizer is used to encode both languages below.
    tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus(
        (s.numpy() for pt, en in train_examples for s in (pt, en)), target_vocab_size=2**14)

    tokenizer.save_to_file("tokenizer")
    tokenizer = tfds.features.text.SubwordTextEncoder.load_from_file("tokenizer")

    return tokenizer, train_examples


def test_tokenizer_success(tokenizer):
    sample_string = 'TensorLayer is awesome.'

    tokenized_string = tokenizer.encode(sample_string)
    print('Tokenized string is {}'.format(tokenized_string))

    original_string = tokenizer.decode(tokenized_string)
    print('The original string: {}'.format(original_string))
    assert original_string == sample_string


def generate_training_dataset(train_examples, tokenizer):

    def encode(lang1, lang2):
        # Append the end-of-sequence id (vocab_size + 1) to both source and target.
        lang1 = tokenizer.encode(lang1.numpy()) + [tokenizer.vocab_size + 1]
        lang2 = tokenizer.encode(lang2.numpy()) + [tokenizer.vocab_size + 1]
        return lang1, lang2

    MAX_LENGTH = 50

    def filter_max_length(x, y, max_length=MAX_LENGTH):
        return tf.logical_and(tf.size(x) <= max_length,
                              tf.size(y) <= max_length)

    def tf_encode(pt, en):
        return tf.py_function(encode, [pt, en], [tf.int64, tf.int64])

    train_dataset = train_examples.map(tf_encode)
    train_dataset = train_dataset.filter(filter_max_length)
    # Cache the dataset in memory to get a speedup while reading from it.
    train_dataset = train_dataset.cache()
    BUFFER_SIZE = 20000
    BATCH_SIZE = 64
    train_dataset = train_dataset.shuffle(BUFFER_SIZE).padded_batch(
        BATCH_SIZE, padded_shapes=([-1], [-1]))
    train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return train_dataset


def model_setup(tokenizer):
    # Define the hyper-parameters for the Transformer.
    class HYPER_PARAMS(object):
        vocab_size = tokenizer.vocab_size + 10
        encoder_num_layers = 4
        decoder_num_layers = 4
        hidden_size = 128
        ff_size = 512
        num_heads = 8
        keep_prob = 0.9

        # Default prediction params
        extra_decode_length = 50
        beam_size = 5
        alpha = 0.6  # used to calculate length normalization in beam search

        label_smoothing = 0.1
        learning_rate = 2.0
        learning_rate_decay_rate = 1.0
        learning_rate_warmup_steps = 4000

        sos_id = 0
        eos_id = tokenizer.vocab_size + 1

    model = Transformer(HYPER_PARAMS)

    # Set the optimizer
    learning_rate = CustomSchedule(HYPER_PARAMS.hidden_size, warmup_steps=HYPER_PARAMS.learning_rate_warmup_steps)
    optimizer = tl.optimizers.LazyAdamOptimizer(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    return model, optimizer, HYPER_PARAMS


# Use the Adam optimizer with a custom learning rate schedule following the formula in the paper "Attention Is All You Need".
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, d_model, warmup_steps=5):
        super(CustomSchedule, self).__init__()

        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)


def tutorial_transformer():
    tokenizer, train_examples = set_up_dataset()
    train_dataset = generate_training_dataset(train_examples, tokenizer)
    model, optimizer, HYPER_PARAMS = model_setup(tokenizer)

    num_epochs = 10
    for epoch in range(num_epochs):
        model.train()
        for (batch, (inp, tar)) in enumerate(train_dataset):
            with tf.GradientTape() as tape:
                logits, weights_encoder, weights_decoder = model(inputs=inp, targets=tar)
                logits = metrics.MetricLayer(HYPER_PARAMS.vocab_size)([logits, tar])
                logits, loss = metrics.LossLayer(HYPER_PARAMS.vocab_size, 0.1)([logits, tar])
            grad = tape.gradient(loss, model.all_weights)
            optimizer.apply_gradients(zip(grad, model.all_weights))
            if batch % 50 == 0:
                print('Batch ID {} at Epoch [{}/{}]: loss {:.4f}'.format(batch, epoch + 1, num_epochs, loss))

    model.eval()
    sentence_en = tokenizer.encode('TensorLayer is awesome.')
    [prediction, weights_decoder], weights_encoder = model(inputs=[sentence_en])

    predicted_sentence = tokenizer.decode([i for i in prediction["outputs"][0]
                                           if i < tokenizer.vocab_size])
    print("Translated: ", predicted_sentence)

    # Visualise the self-attention of the first encoder layer.
    tokenizer_str = [tokenizer.decode([ts]) for ts in sentence_en]
    attention_visualisation.plot_attention_weights(weights_encoder["layer_0"], tokenizer_str, tokenizer_str)


if __name__ == "__main__":
    tutorial_transformer()
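
The CustomSchedule above implements the warmup schedule from "Attention Is All You Need": the learning rate ramps up for warmup_steps steps and then decays with the inverse square root of the step number, scaled by 1/sqrt(d_model). As a quick sanity check, here is a minimal sketch (not part of this commit) that evaluates the schedule over a range of steps, reusing the tf and plt imports and the CustomSchedule class from the tutorial above:

# Sketch: inspect the warmup schedule defined by CustomSchedule above.
# lrate = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
schedule = CustomSchedule(d_model=128, warmup_steps=4000)
steps = tf.range(1, 20000, dtype=tf.float32)
plt.plot(steps.numpy(), schedule(steps).numpy())  # linear ramp-up, then ~step**-0.5 decay
plt.xlabel('training step')
plt.ylabel('learning rate')
plt.show()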

tensorlayer/models/transformer/beamsearchHelper/beam_search.py

Lines changed: 5 additions & 5 deletions
@@ -39,11 +39,11 @@ def search(self, initial_ids, initial_cache):
         finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
         finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]

-        # Account for corner case where there are no finished sequences for a
-        # particular batch item. In that case, return alive sequences for that batch
-        # item.
-        finished_seq = tf.where(tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
-        finished_scores = tf.where(tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
+        # # Account for corner case where there are no finished sequences for a
+        # # particular batch item. In that case, return alive sequences for that batch
+        # # item.
+        # finished_seq = tf.where(tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
+        # finished_scores = tf.where(tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
         return finished_seq, finished_scores

tensorlayer/models/transformer/feedforward_layer.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 import tensorlayer as tl


-class FeedForwardLayer(tl.layers.Layer):
+class TransformerFeedForwardLayer(tl.layers.Layer):
     """Fully connected feedforward network."""

     def __init__(self, hidden_size, filter_size, keep_prob):
@@ -33,7 +33,7 @@ def __init__(self, hidden_size, filter_size, keep_prob):
             filter_size: int, filter size for the inner (first) dense layer.
             relu_dropout: float, dropout rate for training.
         """
-        super(FeedForwardLayer, self).__init__()
+        super(TransformerFeedForwardLayer, self).__init__()
         self.hidden_size = hidden_size
         self.filter_size = filter_size
         self.relu_dropout = 1 - keep_prob

tensorlayer/models/transformer/transformer.py

Lines changed: 8 additions & 6 deletions
@@ -26,7 +26,7 @@
 from tensorlayer.models import Model
 import tensorlayer.models.transformer.embedding_layer as embedding_layer
 from tensorlayer.models.transformer.attention_layer import SelfAttentionLayer, MultiHeadAttentionLayer
-from tensorlayer.models.transformer.feedforward_layer import FeedForwardLayer
+from tensorlayer.models.transformer.feedforward_layer import TransformerFeedForwardLayer
 from tensorlayer.models.transformer.utils.model_utils import positional_encoding
 from tensorlayer.models.transformer.utils.model_utils import get_decoder_self_attention_bias as get_target_mask
 from tensorlayer.models.transformer.utils.model_utils import get_padding_bias as get_input_mask
@@ -56,6 +56,8 @@ class Transformer(Model):
 >>> extra_decode_length = 5
 >>> beam_size = 1
 >>> alpha = 0.6
+>>> eos_id = 1
+>>> sos_id = 0
 >>> model = Transformer(TINY_PARAMS)

 Returns
@@ -224,7 +226,7 @@ def decode(self, targets, encoder_outputs, attention_bias):
         decoder_inputs = self.embedding_softmax_layer(targets)
         with tf.name_scope("shift_targets"):
             # Shift targets to the right, and remove the last element
-            decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
+            decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]], constant_values=self.params.sos_id)[:, :-1, :]
         with tf.name_scope("add_pos_encoding"):
             length = tf.shape(decoder_inputs)[1]
             decoder_inputs += positional_encoding(length, self.params.hidden_size)
@@ -294,7 +296,7 @@ def predict(self, encoder_outputs, encoder_decoder_attention_bias):
         symbols_to_logits_fn, weights = self._get_symbols_to_logits_fn(max_decode_length)

         # Create initial set of IDs that will be passed into symbols_to_logits_fn.
-        initial_ids = tf.zeros([batch_size], dtype=tf.int32)
+        initial_ids = tf.ones([batch_size], dtype=tf.int32) * self.params.sos_id

         # Create cache storing decoder attention values for each layer.
         # pylint: disable=g-complex-comprehension
@@ -314,7 +316,7 @@ def predict(self, encoder_outputs, encoder_decoder_attention_bias):
         decoded_ids, scores = beam_search.sequence_beam_search(
             symbols_to_logits_fn=symbols_to_logits_fn, initial_ids=initial_ids, initial_cache=cache,
             vocab_size=self.params.vocab_size, beam_size=self.params.beam_size, alpha=self.params.alpha,
-            max_decode_length=max_decode_length, eos_id=1
+            max_decode_length=max_decode_length, eos_id=self.params.eos_id
         )

         # Get the top sequence for each batch element
@@ -421,7 +423,7 @@ def __init__(self, params):
         for _ in range(params.encoder_num_layers):
             # Create sublayers for each layer.
             self_attention_layer = SelfAttentionLayer(params.num_heads, params.hidden_size, params.keep_prob)
-            feed_forward_network = FeedForwardLayer(params.hidden_size, params.ff_size, params.keep_prob)
+            feed_forward_network = TransformerFeedForwardLayer(params.hidden_size, params.ff_size, params.keep_prob)

             self.layers.append(
                 [
@@ -488,7 +490,7 @@ def __init__(self, params):
         for _ in range(params.decoder_num_layers):
             self_attention_layer = SelfAttentionLayer(params.num_heads, params.hidden_size, params.keep_prob)
             enc_dec_attention_layer = MultiHeadAttentionLayer(params.num_heads, params.hidden_size, params.keep_prob)
-            feed_forward_network = FeedForwardLayer(params.hidden_size, params.ff_size, params.keep_prob)
+            feed_forward_network = TransformerFeedForwardLayer(params.hidden_size, params.ff_size, params.keep_prob)

             self.layers.append(
                 [
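
The changes above stop hard-coding the start and end tokens: decode() now pads the shifted decoder inputs with params.sos_id, predict() seeds beam search with params.sos_id, and beam search terminates on params.eos_id instead of a fixed 1. A minimal sketch of the params object this implies, with illustrative values (only extra_decode_length, beam_size, alpha, sos_id and eos_id are taken from the diff above; the remaining values are placeholders):

# Sketch (illustrative values): the params namespace passed to Transformer
# must now also carry sos_id and eos_id.
class TINY_PARAMS(object):
    vocab_size = 50
    encoder_num_layers = 2
    decoder_num_layers = 2
    hidden_size = 64
    ff_size = 16
    num_heads = 4
    keep_prob = 0.9

    # prediction params
    extra_decode_length = 5
    beam_size = 1
    alpha = 0.6

    # start/end-of-sequence token ids, read by decode() and predict()
    sos_id = 0
    eos_id = 1

model = Transformer(TINY_PARAMS)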

tensorlayer/models/transformer/utils/attention_visualisation.py

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@ def plot_attention_weights(attention, key, query):
     '''

     fig = plt.figure(figsize=(16, 8))
-
     attention = tf.squeeze(attention, axis=0)

     for head in range(attention.shape[0]):

tensorlayer/optimizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,3 +10,4 @@
 """

 from .amsgrad import AMSGrad
+from .lazy_adam import LazyAdamOptimizer

tensorlayer/optimizers/lazy_adam.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer from addons and learning rate scheduler."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


class LazyAdamOptimizer(tf.optimizers.Adam):
    """Variant of the Adam optimizer that handles sparse updates more efficiently.

    The original Adam algorithm maintains two moving-average accumulators for
    each trainable variable; the accumulators are updated at every step.
    This class provides lazier handling of gradient updates for sparse
    variables. It only updates moving-average accumulators for sparse variable
    indices that appear in the current batch, rather than updating the
    accumulators for all indices. Compared with the original Adam optimizer,
    it can provide large improvements in model training throughput for some
    applications. However, it provides slightly different semantics than the
    original Adam algorithm, and may lead to different empirical results.
    Note: amsgrad is currently not supported, and the argument can only be
    False.

    This class is borrowed from:
    https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/lazy_adam.py
    """

    def _resource_apply_sparse(self, grad, var, indices):
        """Applies grad for one step."""
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        beta_1_t = self._get_hyper('beta_1', var_dtype)
        beta_2_t = self._get_hyper('beta_2', var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.math.pow(beta_1_t, local_step)
        beta_2_power = tf.math.pow(beta_2_t, local_step)
        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
        lr = (lr_t * tf.math.sqrt(1 - beta_2_power) / (1 - beta_1_power))

        # \\(m := beta1 * m + (1 - beta1) * g_t\\)
        m = self.get_slot(var, 'm')
        m_t_slice = beta_1_t * tf.gather(m, indices) + (1 - beta_1_t) * grad

        m_update_kwargs = {'resource': m.handle, 'indices': indices, 'updates': m_t_slice}
        m_update_op = tf.raw_ops.ResourceScatterUpdate(**m_update_kwargs)

        # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
        v = self.get_slot(var, 'v')
        v_t_slice = (beta_2_t * tf.gather(v, indices) + (1 - beta_2_t) * tf.math.square(grad))

        v_update_kwargs = {'resource': v.handle, 'indices': indices, 'updates': v_t_slice}
        v_update_op = tf.raw_ops.ResourceScatterUpdate(**v_update_kwargs)

        # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
        var_slice = lr * m_t_slice / (tf.math.sqrt(v_t_slice) + epsilon_t)

        var_update_kwargs = {'resource': var.handle, 'indices': indices, 'updates': var_slice}
        var_update_op = tf.raw_ops.ResourceScatterSub(**var_update_kwargs)

        return tf.group(*[var_update_op, m_update_op, v_update_op])
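
LazyAdamOptimizer only overrides _resource_apply_sparse, so its benefit shows up when a gradient arrives as IndexedSlices, e.g. from an embedding lookup: only the gathered rows of the variable and of the m/v slots are read and written. A minimal usage sketch (not part of this commit, assuming the tl.optimizers export added in __init__.py above):

# Sketch: sparse update path of LazyAdamOptimizer on an embedding table.
import tensorflow as tf
import tensorlayer as tl

table = tf.Variable(tf.random.uniform([10000, 128]))       # large, sparsely used embedding table
optimizer = tl.optimizers.LazyAdamOptimizer(learning_rate=1e-3, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

ids = tf.constant([3, 17, 42])                              # only these rows receive gradients
with tf.GradientTape() as tape:
    rows = tf.gather(table, ids)
    loss = tf.reduce_sum(tf.square(rows))
grads = tape.gradient(loss, [table])                        # IndexedSlices, not a dense tensor
optimizer.apply_gradients(zip(grads, [table]))              # updates only rows 3, 17 and 42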
