
Commit 2875749

add Factorization Machines

1 parent 27de5e8

4 files changed (+133, −19 lines)

README.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -19,8 +19,8 @@ Download these datasets and unzip in the root of this directory:
 5. Personalized Ranking for Recommender Systems: [`utils.py`](utils.py). For now, only the BPR loss function is implemented.
 6. Neural Collaborative Filtering for Personalized Ranking: [`neumf.py`](neumf.py)
 7. Sequence-Aware Recommender Systems: [`caser.py`](caser.py)
-8. Feature-Rich Recommender Systems:
-9. Factorization Machines:
+8. Feature-Rich Recommender Systems: [`ctr.py`](ctr.py) contains dataloaders for the CTR dataset
+9. Factorization Machines: [`fm.py`](fm.py)
 10. Deep Factorization Machines:
 
 ## Usage
```

ctr.py

Lines changed: 21 additions & 15 deletions
```diff
@@ -17,7 +17,7 @@ def csv_reader(data_path):
 
 
 class CTRDataset(BaseDataset):
-    def __init__(self, data_dir, min_threshold=4):
+    def __init__(self, data_dir="./ctr", min_threshold=4):
         """Read CTR dataset from train.csv and test.csv
 
         Parameters
@@ -38,6 +38,7 @@ def __init__(self, data_dir, min_threshold=4):
             col: self.train_df[col].value_counts()
             for col in self.feat_cols}
         # Feature mapper maps each unique encoded value to an identifier,
+        # so every value is treated as categorical.
         # Unique values are filtered to those occurring at least min_threshold times.
         # A default value is assigned to values not defined in the feature mapper.
         self.feat_mapper = {}
@@ -53,38 +54,43 @@ def _constant_factory(v):
         # Feature dimension = number of unique values = number of values in mapper + defaults
         self.feat_dims = np.array([len(mapper) + 1
                                    for mapper in self.feat_mapper.values()])
-        # Offset is a value added to the whole field to discriminate column order
-        self.offsets = np.array((0, *np.cumsum(self.feat_dims).tolist()[:-1]))
+        # Offset is a value added to the whole field to discriminate values in different columns
+        self.offsets = np.array((0, *np.cumsum(self.feat_dims).tolist()[:-1])).astype(np.int32)
         # Map values in dataframe
         for col, mapper in self.feat_mapper.items():
             self.train_df[col] = self.train_df[col].map(mapper)
             self.test_df[col] = self.test_df[col].map(mapper)
         # Materialized arrays for each split
-        self.getitem_df = None
+        self.X = None
+        self.y = None
+
+    def build_items(self, train=True):
+        if train:
+            df = self.train_df
+        else:
+            df = self.test_df
+        self.X = df[self.feat_cols].values + self.offsets
+        self.y = df[0].values
 
     def split(self, *args, **kwargs) -> Tuple[BaseDataset, BaseDataset]:
         train_split = deepcopy(self)
-        del train_split.test_df
-        train_split.getitem_df = self.train_df
+        train_split.build_items(True)
 
         test_split = deepcopy(self)
-        del test_split.train_df
-        test_split.getitem_df = self.test_df
+        test_split.build_items(False)
+
         return train_split, test_split
 
     def __len__(self):
-        assert self.getitem_df is not None
-        return len(self.getitem_df)
+        assert self.X is not None and self.y is not None
+        return len(self.X)
 
     def __getitem__(self, idx):
-        assert self.getitem_df is not None
-        x = self.getitem_df[self.feat_cols].iloc[idx].values + self.offsets
-        y = self.getitem_df[0].iloc[idx]
-        return x, y
+        assert self.X is not None and self.y is not None
+        return self.X[idx], self.y[idx]
 
 
 if __name__ == "__main__":
     data = CTRDataset("ctr")
     train_split, test_split = data.split()
     print(train_split[0])
-    print(test_split[0])
```
fm.py

Lines changed: 108 additions & 0 deletions
```diff
@@ -0,0 +1,108 @@
+from argparse import ArgumentParser
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pytorch_lightning.loggers import TensorBoardLogger
+from torchmetrics import Accuracy
+
+from lit_data import LitDataModule
+from lit_model import LitModel
+from ctr import CTRDataset
+
+
+class FactorizationMachine(nn.Module):
+    def __init__(self, feat_dims, embedding_dims):
+        super().__init__()
+        num_inputs = int(sum(feat_dims))
+        self.embedding = nn.Embedding(num_inputs, embedding_dims)
+        self.proj = nn.Embedding(num_inputs, 1)
+        self.fc = nn.Linear(1, 1)
+        for param in self.parameters():
+            # Xavier init needs 2+ dims; skip the 1-d bias parameters
+            if param.dim() > 1:
+                nn.init.xavier_normal_(param)
+
+    def forward(self, x):
+        v = self.embedding(x)
+        # Pairwise interactions in linear time: 1/2 * sum_f((sum_i v_if)^2 - sum_i v_if^2)
+        interaction = 1/2 * (v.sum(1)**2 - (v**2).sum(1)).sum(-1, keepdim=True)
+        proj = self.proj(x).sum(1)
+        logit = self.fc(proj + interaction)
+        return torch.sigmoid(logit).flatten()
+
+
+class LitFM(pl.LightningModule):
+    def __init__(self, lr=0.002, **kwargs):
+        super().__init__()
+        self.save_hyperparameters()
+        self.model = FactorizationMachine(**kwargs)
+        self.lr = lr
+        self.train_acc = Accuracy()
+        self.test_acc = Accuracy()
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), self.lr, weight_decay=1e-5)
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        ypred = self(x)
+        loss = F.binary_cross_entropy(ypred, y.to(torch.float32))
+        self.train_acc.update(ypred, y)
+        return {"loss": loss}
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        ypred = self(x)
+        loss = F.binary_cross_entropy(ypred, y.to(torch.float32))
+        self.test_acc.update(ypred, y)
+        return {"loss": loss}
+
+    def training_epoch_end(self, outputs):
+        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+        acc = self.train_acc.compute()
+        self.train_acc.reset()
+        self.logger.experiment.add_scalar(
+            "train/loss", avg_loss, self.current_epoch)
+        self.logger.experiment.add_scalar(
+            "train/acc", acc, self.current_epoch)
+
+    def validation_epoch_end(self, outputs):
+        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+        acc = self.test_acc.compute()
+        self.test_acc.reset()
+        self.logger.experiment.add_scalar(
+            "val/loss", avg_loss, self.current_epoch)
+        self.logger.experiment.add_scalar(
+            "val/acc", acc, self.current_epoch)
+
+
+def main(args):
+    data = LitDataModule(
+        CTRDataset(),
+        batch_size=args.batch_size,
+        num_workers=3,
+        prefetch_factor=4)
+    data.setup()
+
+    model = LitFM(
+        feat_dims=data.dataset.feat_dims,
+        embedding_dims=args.embedding_dims)
+
+    logger = TensorBoardLogger("lightning_logs", name=f"FM_{args.embedding_dims}")
+    trainer = pl.Trainer.from_argparse_args(args, logger=logger)
+    trainer.fit(model, data)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--embedding_dims", type=int, default=20)
+    parser.add_argument("--batch_size", type=int, default=1024)
+    pl.Trainer.add_argparse_args(parser)
+    args = parser.parse_args()
+    main(args)
```
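`FactorizationMachine.forward` relies on the standard FM identity sum_{i<j} <v_i, v_j> = 1/2 * sum_f((sum_i v_if)^2 - sum_i v_if^2), which computes all pairwise feature interactions in time linear in the number of fields rather than quadratic. A quick standalone sanity check of that identity (not part of the commit):

```python
import torch

torch.manual_seed(0)
v = torch.randn(8, 4)  # 8 fields, embedding dimension 4

# Naive O(fields^2) sum over all pairs of field embeddings
naive = sum((v[i] * v[j]).sum() for i in range(8) for j in range(i + 1, 8))

# Linear-time reformulation used in forward()
fast = 0.5 * ((v.sum(0) ** 2) - (v ** 2).sum(0)).sum()

assert torch.allclose(naive, fast, atol=1e-5)
```

Assuming the usual PyTorch Lightning 1.x argparse integration, a training run would look something like `python fm.py --embedding_dims 20 --batch_size 1024 --max_epochs 20`; the exact `Trainer` flags depend on the installed `pytorch_lightning` version.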

lit_data.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -43,8 +43,8 @@ def __init__(self, dataset: BaseDataset,
         }
 
     def setup(self):
-        self.num_users = self.dataset.num_users
-        self.num_items = self.dataset.num_items
+        self.num_users = getattr(self.dataset, "num_users", None)
+        self.num_items = getattr(self.dataset, "num_items", None)
         self.train_split, self.test_split = self.dataset.split(
             self.train_ratio)
```
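The `getattr` fallback lets the same `LitDataModule` wrap datasets with no user/item notion, such as `CTRDataset`. A tiny illustration with a hypothetical stand-in class:

```python
class TabularOnlyDataset:
    """Hypothetical dataset exposing neither num_users nor num_items."""

ds = TabularOnlyDataset()
print(getattr(ds, "num_users", None))  # None, instead of an AttributeError
```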
