model_embeddings.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch.nn as nn
class ModelEmbeddings(nn.Module):
    """
    Class that converts input words to their embeddings.
    """

    def __init__(self, embed_size, vocab):
        """
        Init the Embedding layers.

        @param embed_size (int): Embedding size (dimensionality)
        @param vocab (Vocab): Vocabulary object containing src and tgt languages
                              See vocab.py for documentation.
        """
        super(ModelEmbeddings, self).__init__()
        self.embed_size = embed_size

        # Look up the padding-token index for each language so the corresponding
        # embedding row is kept at zero and excluded from gradient updates.
        src_pad_token_idx = vocab.src['<pad>']
        tgt_pad_token_idx = vocab.tgt['<pad>']

        # One embedding table per language: (vocabulary size) x (embed_size).
        self.source = nn.Embedding(len(vocab.src), self.embed_size, padding_idx=src_pad_token_idx)
        self.target = nn.Embedding(len(vocab.tgt), self.embed_size, padding_idx=tgt_pad_token_idx)
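

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal, hypothetical example of how ModelEmbeddings could be exercised.
# _ToyVocab and _ToyVocabEntry below are stand-ins for the real Vocab object
# defined in vocab.py; only the pieces this module touches (vocab.src, vocab.tgt,
# __getitem__, __len__) are mocked, and the word lists are invented for the demo.
if __name__ == '__main__':
    import torch

    class _ToyVocabEntry:
        """Hypothetical minimal vocabulary: word -> index mapping."""
        def __init__(self, words):
            self.word2id = {w: i for i, w in enumerate(words)}

        def __getitem__(self, word):
            return self.word2id[word]

        def __len__(self):
            return len(self.word2id)

    class _ToyVocab:
        """Hypothetical container mirroring the src/tgt attributes of Vocab."""
        def __init__(self, src_words, tgt_words):
            self.src = _ToyVocabEntry(src_words)
            self.tgt = _ToyVocabEntry(tgt_words)

    vocab = _ToyVocab(['<pad>', '<s>', '</s>', 'hello', 'world'],
                      ['<pad>', '<s>', '</s>', 'bonjour', 'monde'])
    embeddings = ModelEmbeddings(embed_size=8, vocab=vocab)

    # A (batch, seq_len) tensor of source-word indices; the second row is
    # padded with index 0, i.e. '<pad>', so its last embedding is all zeros.
    src_indices = torch.tensor([[1, 3, 4, 2],
                                [1, 3, 2, 0]])
    src_embedded = embeddings.source(src_indices)
    print(src_embedded.shape)  # torch.Size([2, 4, 8])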