8-transformer_decoder_block.py
#!/usr/bin/env python3
"""DecoderBlock class to create a decoder block for a transformer"""
import tensorflow as tf

MultiHeadAttention = __import__('6-multihead_attention').MultiHeadAttention


class DecoderBlock(tf.keras.layers.Layer):
    """DecoderBlock class"""

    def __init__(self, dm, h, hidden, drop_rate=0.1):
        """Class constructor

        dm: dimensionality of the model
        h: number of attention heads
        hidden: number of hidden units in the feed-forward layer
        drop_rate: dropout rate
        """
        super().__init__()
        # First attention block: masked self-attention on the decoder input
        self.mha1 = MultiHeadAttention(dm, h)
        # Second attention block: cross-attention over the encoder output
        self.mha2 = MultiHeadAttention(dm, h)
        # Point-wise feed-forward network
        self.dense_hidden = tf.keras.layers.Dense(hidden, activation='relu')
        self.dense_output = tf.keras.layers.Dense(dm)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(drop_rate)
        self.dropout2 = tf.keras.layers.Dropout(drop_rate)
        self.dropout3 = tf.keras.layers.Dropout(drop_rate)

    def call(self, x, encoder_output, training, look_ahead_mask, padding_mask):
        """Call method

        x: decoder input of shape (batch, target_seq_len, dm)
        encoder_output: encoder output of shape (batch, input_seq_len, dm)
        training: boolean to determine if the model is training
        look_ahead_mask: mask for the self-attention layer
        padding_mask: mask for the cross-attention layer
        Returns: tensor of shape (batch, target_seq_len, dm)
        """
        # Masked self-attention, dropout, then residual + layer norm
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(attn1 + x)

        # Cross-attention over the encoder output
        attn2, attn_weights_block2 = self.mha2(out1, encoder_output,
                                               encoder_output, padding_mask)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(attn2 + out1)

        # Point-wise feed-forward network, dropout, residual + layer norm
        ffn = tf.keras.Sequential([
            self.dense_hidden,
            self.dense_output
        ])
        ffn_output = ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(ffn_output + out2)
        return out3
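
A minimal usage sketch (not part of the file above), assuming 6-multihead_attention.py and this file are on the import path and that the project's MultiHeadAttention handles None masks:

# Hypothetical usage sketch: builds a DecoderBlock and runs a forward pass
# on random tensors.  The masks are passed as None here, which assumes the
# underlying scaled dot-product attention accepts a missing mask.
import tensorflow as tf

DecoderBlock = __import__('8-transformer_decoder_block').DecoderBlock

dblock = DecoderBlock(dm=512, h=8, hidden=2048)
x = tf.random.uniform((32, 15, 512))        # decoder input
enc_out = tf.random.uniform((32, 10, 512))  # encoder output
out = dblock(x, enc_out, training=False,
             look_ahead_mask=None, padding_mask=None)
print(out.shape)  # (32, 15, 512)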