-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path11-transformer.py
executable file
·28 lines (24 loc) · 1.15 KB
/
11-transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#!/usr/bin/env python3
"""Transformer class to create a full (encoder + decoder) transformer"""
import tensorflow as tf
# The project modules' names start with a digit and contain hyphens, which
# are not valid Python identifiers, so a plain `import` statement cannot load
# them; `__import__` takes the module name as a string instead.
Encoder = __import__('9-transformer_encoder').Encoder
Decoder = __import__('10-transformer_decoder').Decoder
class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer model.

    Owns three sub-layers, created in this order: ``encoder`` (an
    ``Encoder`` over the input vocabulary), ``decoder`` (a ``Decoder`` over
    the target vocabulary) and ``linear`` (a ``Dense`` projection to
    ``target_vocab`` units with no activation).
    """

    def __init__(self, N, dm, h, hidden, input_vocab, target_vocab,
                 max_seq_input, max_seq_target, drop_rate=0.1):
        """Build the encoder, the decoder and the output projection.

        Args:
            N, dm, h, hidden: hyper-parameters forwarded unchanged, in this
                order, to both ``Encoder`` and ``Decoder`` (presumably block
                count, model width, head count and feed-forward width --
                confirm against those classes).
            input_vocab: vocabulary size given to the encoder only.
            target_vocab: vocabulary size given to the decoder; also the
                number of output units of the final ``Dense`` layer.
            max_seq_input: maximum sequence length given to the encoder.
            max_seq_target: maximum sequence length given to the decoder.
            drop_rate: dropout rate forwarded to both stacks.
        """
        super().__init__()
        # Leading arguments shared by both stacks, in the positional order
        # the Encoder and Decoder constructors receive them.
        shared = (N, dm, h, hidden)
        # Creation order (encoder, decoder, linear) is kept deliberately:
        # Keras derives default sub-layer names from construction order.
        self.encoder = Encoder(*shared, input_vocab,
                               max_seq_input, drop_rate)
        self.decoder = Decoder(*shared, target_vocab,
                               max_seq_target, drop_rate)
        # No activation is given, so the projection is purely linear
        # (unnormalised scores over the target vocabulary).
        self.linear = tf.keras.layers.Dense(target_vocab)

    def call(self, inputs, target, training, encoder_mask,
             look_ahead_mask, decoder_mask):
        """Run the encoder, then the decoder, then the output projection.

        Args:
            inputs: source sequence fed to the encoder (presumably
                (batch, input_seq_len) token ids -- confirm).
            target: target sequence fed to the decoder (presumably
                (batch, target_seq_len) token ids -- confirm).
            training: forwarded positionally to both encoder and decoder.
            encoder_mask: mask passed to the encoder.
            look_ahead_mask: first mask passed to the decoder.
            decoder_mask: second mask passed to the decoder.

        Returns:
            The decoder output projected by ``self.linear``; its last
            dimension is ``target_vocab``.
        """
        memory = self.encoder(inputs, training, encoder_mask)
        decoded = self.decoder(target, memory, training,
                               look_ahead_mask, decoder_mask)
        return self.linear(decoded)