-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path9-transformer_encoder.py
executable file
·32 lines (28 loc) · 1.17 KB
/
9-transformer_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#!/usr/bin/env python3
"""Encoder class o create the encoder for a transformer"""
import tensorflow as tf
positional_encoding = __import__('4-positional_encoding').positional_encoding
EncoderBlock = __import__('7-transformer_encoder_block').EncoderBlock
class Encoder(tf.keras.layers.Layer):
    """Transformer encoder: token embedding plus positional encoding,
    followed by a stack of N identical encoder blocks."""

    def __init__(self, N, dm, h, hidden, input_vocab,
                 max_seq_len, drop_rate=0.1):
        """Build the embedding, positional encodings, blocks and dropout.

        N: number of encoder blocks in the stack
        dm: model depth (embedding dimension)
        h: number of attention heads per block
        hidden: units in each block's feed-forward hidden layer
        input_vocab: size of the input vocabulary
        max_seq_len: maximum sequence length supported
        drop_rate: dropout rate applied after adding positional encodings
        """
        super().__init__()
        self.N = N
        self.dm = dm
        self.embedding = tf.keras.layers.Embedding(input_vocab, dm)
        self.positional_encoding = positional_encoding(max_seq_len, dm)
        self.blocks = [EncoderBlock(dm, h, hidden, drop_rate)
                       for _ in range(N)]
        self.dropout = tf.keras.layers.Dropout(drop_rate)

    def call(self, x, training, mask):
        """Run the input through the full encoder stack.

        x: tensor of token ids, shape (batch, input_seq_len)
        training: bool, whether dropout is active
        mask: padding mask forwarded to every encoder block
        Returns a tensor of shape (batch, input_seq_len, dm).
        """
        seq_len = x.shape[1]
        # Scale embeddings by sqrt(dm) before adding positional info.
        scale = tf.math.sqrt(tf.cast(self.dm, tf.float32))
        out = self.embedding(x) * scale
        out = out + self.positional_encoding[:seq_len, :]
        out = self.dropout(out, training=training)
        for block in self.blocks:
            out = block(out, training, mask)
        return out