from argparse import ArgumentParser
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics import Accuracy
from lit_data import LitDataModule
from lit_model import LitModel
from ctr import CTRDataset
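

# Second-order factorization machine (FM) for click-through-rate prediction:
# a per-feature linear term plus pairwise feature interactions expressed
# through shared low-rank embeddings.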
class FactorizationMachine(nn.Module):
def __init__(self, feat_dims, embedding_dims, **kwargs):
super().__init__()
num_inputs = int(sum(feat_dims))
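        # Single embedding table over the concatenated feature vocabularies
        # (the dataset is assumed to offset each field's indices accordingly):
        # `embedding` holds the interaction factors, `proj` the per-feature
        # linear weights, and `fc` a final scale/bias on their sum.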
self.embedding = nn.Embedding(num_inputs, embedding_dims)
self.proj = nn.Embedding(num_inputs, 1)
self.fc = nn.Linear(1, 1)
        # Xavier-initialize the matrix-shaped parameters (embedding tables and
        # the final linear weight); skip 1-d parameters such as the bias,
        # for which xavier_normal_ raises a ValueError.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_normal_(param)

    def forward(self, x, return_logit=False):
        v = self.embedding(x)
        # Pairwise-interaction term via the FM identity:
        # sum_{i<j} <v_i, v_j> = 0.5 * ((sum_i v_i)^2 - sum_i v_i^2),
        # reduced over the embedding dimension.
        interaction = 0.5 * (v.sum(1) ** 2 - (v ** 2).sum(1)).sum(-1, keepdim=True)
        # First-order (linear) term: one scalar weight per feature index.
        proj = self.proj(x).sum(1)
        logit = self.fc(proj + interaction).flatten()
if return_logit:
return logit
else:
return torch.sigmoid(logit)
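

# LightningModule wrapper around FactorizationMachine: optimizer, binary
# accuracy metrics, and per-epoch TensorBoard logging.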
class LitFM(pl.LightningModule):
def __init__(self, lr=0.002, **kwargs):
super().__init__()
print("LitFM")
self.save_hyperparameters()
self.model = FactorizationMachine(**kwargs)
self.lr = lr
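        # Bare Accuracy() matches the older torchmetrics API used alongside
        # Lightning 1.x here; recent torchmetrics releases require
        # Accuracy(task="binary") instead.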
self.train_acc = Accuracy()
self.test_acc = Accuracy()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), self.lr, weight_decay=1e-5)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        ypred = self(x)
        loss = F.binary_cross_entropy(ypred, y.to(torch.float32))
        self.train_acc.update(ypred, y)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        ypred = self(x)
        loss = F.binary_cross_entropy(ypred, y.to(torch.float32))
        self.test_acc.update(ypred, y)
        return {"loss": loss}
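
    # The *_epoch_end hooks below follow the PyTorch Lightning 1.x API (they
    # were removed in Lightning 2.0): aggregate the per-batch losses and log
    # epoch-level loss/accuracy straight to the TensorBoard experiment.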
def training_epoch_end(self, outputs):
avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
acc = self.train_acc.compute()
self.train_acc.reset()
self.logger.experiment.add_scalar(
"train/loss", avg_loss, self.current_epoch)
self.logger.experiment.add_scalar(
"train/acc", acc, self.current_epoch)

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        acc = self.test_acc.compute()
        self.test_acc.reset()
        self.logger.experiment.add_scalar(
            "val/loss", avg_loss, self.current_epoch)
        self.logger.experiment.add_scalar(
            "val/acc", acc, self.current_epoch)


def main(args):
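    # Build the CTR data module and the FM model, then hand both to a
    # pl.Trainer configured from the command-line arguments.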
data = LitDataModule(
CTRDataset(),
batch_size=args.batch_size,
num_workers=3,
prefetch_factor=4)
data.setup()
model = LitFM(
feat_dims=data.dataset.feat_dims,
embedding_dims=args.embedding_dims)
logger = TensorBoardLogger("lightning_logs", name=f"FM_{args.embedding_dims}")
trainer = pl.Trainer.from_argparse_args(args, logger=logger)
trainer.fit(model, data)
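

# CLI: exposes the model/data hyperparameters plus every pl.Trainer argument
# (added by the Lightning 1.x add_argparse_args helper), e.g. --max_epochs.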
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--embedding_dims", type=int, default=20)
parser.add_argument("--batch_size", type=int, default=1024)
pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
main(args)