-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathencoder_model.py
46 lines (43 loc) · 1.32 KB
/
encoder_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from params import *
from torch import nn
from transformer_sub import Encoder
import torch
class EncoderModel(nn.Module):
    """Encoder classification model for sentiment analysis.

    The input is passed through an ``Encoder`` stack, averaged over
    dimension 1, and projected by a linear layer to ``output_dim`` values.
    No softmax is applied in this class, so ``forward`` returns
    unnormalized scores rather than probabilities.

    Args:
        nr_embed: Stored as ``self.nr_embed``; not otherwise used by the
            code in this class.
        embed_dim: Passed to ``Encoder`` and used as the input width of
            the output layer.
        output_dim: Number of output features of the final linear layer.
        pool_type: Stored as ``self.pool_type``.
            NOTE(review): ``forward`` always mean-pools and never reads
            this value -- confirm whether other pooling modes are intended.
        attention: Stored as ``self.attention``; not otherwise used by the
            code in this class.
        heads: Passed to ``Encoder``.
        dropout: Passed to ``Encoder``.
        hidden: Passed to ``Encoder``; not kept as an attribute.
        depth: Passed to ``Encoder``.
    """

    def __init__(
        self,
        nr_embed=VOCAB_SIZE,
        embed_dim=EMBED_DIM,
        output_dim=CLS,
        pool_type=POOL,
        attention=ATTENTION,
        heads=HEADS,
        dropout=DROPOUT,
        hidden=HIDDEN,
        depth=ENCDEPTH,
    ):
        super().__init__()

        # Record the constructor arguments as plain attributes.
        self.nr_embed = nr_embed
        self.embed_dim = embed_dim
        self.output_dim = output_dim
        self.pool_type = pool_type
        self.attention = attention
        self.heads = heads
        self.dropout = dropout
        self.depth = depth

        # Sub-modules are registered in the same order as before
        # (encoder, then out) so parameter/state_dict ordering is unchanged.
        self.encoder = Encoder(
            embed_dim=embed_dim,
            heads=heads,
            dropout=dropout,
            hidden=hidden,
            depth=depth,
        )
        self.out = nn.Linear(embed_dim, output_dim)

    def forward(self, x):
        """Encode ``x``, mean-pool over dimension 1, and apply the output layer.

        Args:
            x: Input handed directly to ``self.encoder``.
                Presumably already-embedded features of shape
                (batch, seq, embed_dim) -- TODO confirm against ``Encoder``.

        Returns:
            The output of the final linear layer, whose last dimension is
            ``output_dim``.
        """
        encoded = self.encoder(x)

        # Collapse dimension 1 by averaging.
        # Assumes dimension 1 is the sequence axis -- TODO confirm.
        pooled = encoded.mean(dim=1)

        # Linear projection only; any softmax must be applied by the caller.
        return self.out(pooled)