-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathembeddings.py
80 lines (66 loc) · 2.85 KB
/
embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torch.nn as nn
import torch.nn.functional as F
class item(torch.nn.Module):
    """Embed item features: a rating category plus multi-hot genre,
    director and actor vectors.

    ``config`` must provide ``num_rate``, ``num_genre``, ``num_director``,
    ``num_actor`` and ``embedding_dim``. The rating index goes through an
    ``Embedding`` table. Each multi-hot feature goes through a bias-free
    ``Linear`` layer, and the result is divided by that row's sum.
    """

    def __init__(self, config):
        super().__init__()
        self.num_rate = config.num_rate
        self.num_genre = config.num_genre
        self.num_director = config.num_director
        self.num_actor = config.num_actor
        self.embedding_dim = config.embedding_dim

        self.embedding_rate = torch.nn.Embedding(
            num_embeddings=self.num_rate,
            embedding_dim=self.embedding_dim,
        )
        # Bias-free Linear layers: for a 0/1 multi-hot input, the output is
        # the sum of the weight columns of the active entries.
        self.embedding_genre = torch.nn.Linear(
            in_features=self.num_genre,
            out_features=self.embedding_dim,
            bias=False,
        )
        self.embedding_director = torch.nn.Linear(
            in_features=self.num_director,
            out_features=self.embedding_dim,
            bias=False,
        )
        self.embedding_actor = torch.nn.Linear(
            in_features=self.num_actor,
            out_features=self.embedding_dim,
            bias=False,
        )

    @staticmethod
    def _pooled_embedding(layer, multi_hot):
        """Apply ``layer`` to ``multi_hot`` (as float) and divide by its row sums.

        The row sum is taken over dimension 1 and reshaped to a column.
        ``multi_hot`` is assumed to be (batch, features) -- confirm against
        callers.

        A row whose sum is zero would otherwise compute 0 / 0 = NaN. Its
        divisor is replaced by 1, so the row keeps the layer output unchanged
        (all zeros for a bias-free layer fed an all-zero row). Rows with a
        nonzero sum are unaffected.
        """
        values = multi_hot.float()
        counts = torch.sum(values, 1).view(-1, 1)
        counts = counts.masked_fill(counts == 0, 1.0)
        return layer(values) / counts

    def forward(self, rate_idx, genre_idx, director_idx, actors_idx, vars=None):
        """Return the concatenated item embedding.

        Args:
            rate_idx: index tensor looked up in ``embedding_rate``. Presumably
                shape (batch,), so the lookup is 2-D like the others -- confirm.
            genre_idx: multi-hot genre tensor, width ``num_genre``.
            director_idx: multi-hot director tensor, width ``num_director``.
            actors_idx: multi-hot actor tensor, width ``num_actor``.
            vars: unused here; kept so existing callers that pass it still work.

        Returns:
            ``torch.cat`` along dimension 1 of the rate, genre, director and
            actor embeddings, in that order.
        """
        rate_emb = self.embedding_rate(rate_idx)
        genre_emb = self._pooled_embedding(self.embedding_genre, genre_idx)
        director_emb = self._pooled_embedding(self.embedding_director, director_idx)
        actors_emb = self._pooled_embedding(self.embedding_actor, actors_idx)
        return torch.cat((rate_emb, genre_emb, director_emb, actors_emb), 1)
class user(torch.nn.Module):
    """Embed user features: gender, age, occupation and zipcode area.

    ``config`` must provide ``num_gender``, ``num_age``, ``num_occupation``,
    ``num_zipcode`` and ``embedding_dim``. Each feature index is looked up in
    its own ``torch.nn.Embedding`` table of width ``embedding_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_gender = config.num_gender
        self.num_age = config.num_age
        self.num_occupation = config.num_occupation
        self.num_zipcode = config.num_zipcode
        self.embedding_dim = config.embedding_dim

        def make_table(vocab_size):
            # One lookup table per feature, all sharing the same width.
            return torch.nn.Embedding(
                num_embeddings=vocab_size,
                embedding_dim=self.embedding_dim,
            )

        # Tables are created in a fixed order (gender, age, occupation, area),
        # which fixes the submodule / parameter registration order.
        self.embedding_gender = make_table(self.num_gender)
        self.embedding_age = make_table(self.num_age)
        self.embedding_occupation = make_table(self.num_occupation)
        self.embedding_area = make_table(self.num_zipcode)

    def forward(self, gender_idx, age_idx, occupation_idx, area_idx):
        """Look up each feature and concatenate the results along dimension 1.

        The output order is gender, age, occupation, area. The index tensors
        are presumably shape (batch,), giving (batch, 4 * embedding_dim) --
        confirm against callers.
        """
        pairs = (
            (self.embedding_gender, gender_idx),
            (self.embedding_age, age_idx),
            (self.embedding_occupation, occupation_idx),
            (self.embedding_area, area_idx),
        )
        pieces = [table(indices) for table, indices in pairs]
        return torch.cat(pieces, 1)