model.py
# -*- coding: utf-8 -*-
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(EmbeddingModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        # Separate input (center-word) and output (context-word) embedding tables, as in word2vec.
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size)
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size)
        self.init_embed()

    def init_embed(self):
        # Input embeddings start in a small uniform range; output embeddings start at zero.
        init = 0.5 / self.embed_size
        self.in_embed.weight.data.uniform_(-init, init)
        self.out_embed.weight.data.zero_()
    def forward(self, input_labels, pos_labels, neg_labels):
        '''Compute the skip-gram negative-sampling loss.

        input_labels: center words, [batch_size] (one-dimensional vector of batch size)
        pos_labels: positive (context) words, [batch_size, window_size * 2]
        neg_labels: negative (sampled) words, [batch_size, K]
        return: loss, [batch_size]
        '''
        input_embedding = self.in_embed(input_labels)    # [batch_size, embed_size]
        pos_embedding = self.out_embed(pos_labels)       # [batch_size, window_size * 2, embed_size]
        neg_embedding = self.out_embed(neg_labels)       # [batch_size, K, embed_size]

        input_embedding = input_embedding.unsqueeze(2)   # [batch_size, embed_size, 1]
        pos_dot = torch.bmm(pos_embedding, input_embedding)    # [batch_size, window_size * 2, 1]
        pos_dot = pos_dot.squeeze(2)                     # [batch_size, window_size * 2]
        neg_dot = torch.bmm(neg_embedding, -input_embedding)   # [batch_size, K, 1]
        neg_dot = neg_dot.squeeze(2)                     # [batch_size, K]

        log_pos = F.logsigmoid(pos_dot).sum(1)           # [batch_size]
        log_neg = F.logsigmoid(neg_dot).sum(1)           # [batch_size]
        return -log_pos - log_neg                        # [batch_size]
    def save_embedding(self, outdir, idx2word):
        '''Write the input embeddings to vec.tsv and the corresponding words to word.tsv.'''
        embeds = self.in_embed.weight.data.cpu().numpy()
        # Use context managers so both files are flushed and closed.
        with open(os.path.join(outdir, 'vec.tsv'), 'w') as f1, \
             open(os.path.join(outdir, 'word.tsv'), 'w') as f2:
            for idx in range(len(embeds)):
                word = idx2word[idx]
                embed = '\t'.join([str(x) for x in embeds[idx]])
                f1.write(embed + '\n')
                f2.write(word + '\n')
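
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). It shows how the
# model is typically driven: dummy index tensors stand in for a real batch from
# a dataloader, and the vocab_size / embed_size / window / K values below are
# assumptions chosen only for this example.
if __name__ == '__main__':
    vocab_size, embed_size = 1000, 100
    batch_size, window, K = 4, 3, 5

    model = EmbeddingModel(vocab_size, embed_size)
    input_labels = torch.randint(0, vocab_size, (batch_size,))            # center words
    pos_labels = torch.randint(0, vocab_size, (batch_size, window * 2))   # context words
    neg_labels = torch.randint(0, vocab_size, (batch_size, K))            # negative samples

    loss = model(input_labels, pos_labels, neg_labels).mean()             # scalar loss
    loss.backward()                                                       # gradients for an optimizer step
    print('loss:', loss.item())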