transformer_sub.py
from torch import nn
from attention import MultiHeadedAttention, PositionalEmbedding
from params import *


class EncoderLayer(nn.Module):
    def __init__(self, input_dim, heads=4, dropout=0.0, hidden=2):
        super().__init__()
        self.hidden = hidden
        self.attention = MultiHeadedAttention(input_dim=input_dim, heads=heads, mask=False)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.feedforward = nn.Sequential(
            nn.Linear(input_dim, self.hidden * input_dim),
            nn.ReLU(),
            nn.Linear(self.hidden * input_dim, input_dim)
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # multi-headed self-attention sublayer
        mha = self.attention(x)
        # residual connection followed by layer normalization
        x = self.norm1(mha + x)
        # apply dropout
        x = self.dropout1(x)
        # position-wise feedforward sublayer
        ff = self.feedforward(x)
        # residual connection followed by layer normalization
        x = self.norm2(ff + x)
        # apply dropout
        x = self.dropout2(x)
        # return the encoded representation
        return x


class DecoderLayer(nn.Module):
    # TODO: yet to be implemented!
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class Encoder(nn.Module):
    def __init__(self, embed_dim, heads, dropout, hidden, depth):
        super().__init__()
        self.depth = depth
        self.embedding = nn.Embedding(VOCAB_SIZE, embed_dim)
        self.position = PositionalEmbedding(embed_dim=embed_dim)
        # build the stack of encoder layers
        enc_block = []
        for _ in range(depth):
            enc_block.append(
                EncoderLayer(input_dim=embed_dim,
                             dropout=dropout,
                             hidden=hidden,
                             heads=heads)
            )
        self.encoder_block = nn.Sequential(*enc_block)

    def forward(self, x):
        # convert the token sequence to embeddings and add positional encodings
        x = self.embedding(x)
        x = self.position(x)
        # pass through the encoder stack
        x = self.encoder_block(x)
        return x


class Decoder(nn.Module):
    # TODO: yet to be implemented!
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
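

# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# The hyperparameter values below are assumptions chosen for a quick smoke
# test; VOCAB_SIZE comes from params.py, and MultiHeadedAttention /
# PositionalEmbedding are the project's own modules imported above.
if __name__ == "__main__":
    import torch

    encoder = Encoder(embed_dim=64, heads=4, dropout=0.1, hidden=2, depth=2)
    # a batch of 8 sequences, each 16 tokens long, drawn from the vocabulary
    tokens = torch.randint(0, VOCAB_SIZE, (8, 16))
    encoded = encoder(tokens)
    # expected output shape: (batch, seq_len, embed_dim) -> (8, 16, 64)
    print(encoded.shape)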