-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlogreg_train.py
101 lines (76 loc) · 2.63 KB
/
logreg_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Script to train a one-vs-all logistic regression model.
It saves the model's weights to the path given by --weights_path
(data/weights.pt by default).
"""
import numpy as np
import pandas as pd
from time import time
from argparse import ArgumentParser
from matplotlib import pyplot as plt
from config import Config
from dslr.preprocessing import scale, fill_na
from dslr.multi_classifier import OneVsAllLogisticRegression
def plot_training(model: OneVsAllLogisticRegression):
    """
    Show the loss curve of every sub-model on a single figure.

    One line is drawn per pair taken from ``model.models`` and
    ``model.labels``; the x axis runs from 1 to ``model.epochs`` and the
    y values come from each sub-model's ``hist`` attribute.
    :param model: trained model
    :return: None
    """
    _figure, axes = plt.subplots()
    epoch_axis = range(1, model.epochs + 1)

    # Each sub-model contributes one curve, named after its label.
    for classifier, class_name in zip(model.models, model.labels):
        axes.plot(epoch_axis, classifier.hist, label=class_name)

    axes.set_xlabel('Epochs')
    axes.set_ylabel('Loss')
    title = 'Logistic Regression, batch size: {}'.format(model.batch_size)
    axes.set_title(title)
    axes.legend(loc="upper right")

    plt.show()
def train(data_path: str,
          weights_path: str,
          config_path: str,
          v: bool = False):
    """
    Train a one-vs-all logistic regression on a CSV dataset and save it.

    Prints the preparation, training and total wall-clock times.
    :param data_path: path of the CSV file read with ``pd.read_csv``
    :param weights_path: path handed to ``model.save``
    :param config_path: path handed to ``Config``
    :param v: forwarded to the model as ``save_hist``; when true the loss
              history is also plotted once training is done
    :return: None
    """
    # The configuration decides which feature columns are used.
    settings = Config(config_path)
    feature_names = settings.choosed_features()

    # The preparation timer covers loading, NaN filling and model creation.
    prep_started = time()
    frame = fill_na(pd.read_csv(data_path), feature_names)

    features = frame[feature_names].values
    targets = frame["Hogwarts House"].values

    hyper_params = dict(
        device=settings.device,
        transform=scale[settings.scale],
        lr=settings.lr,
        epochs=settings.epochs,
        batch_size=settings.batch_size,
        seed=settings.seed,
        save_hist=v,
    )
    classifier = OneVsAllLogisticRegression(**hyper_params)
    prep_elapsed = time() - prep_started

    # Fitting is timed separately.
    fit_started = time()
    classifier.fit(features, targets)
    fit_elapsed = time() - fit_started

    # Save weights and scale params.
    classifier.save(weights_path)

    report = (
        ("Preparation time:", prep_elapsed),
        ("Training time:", fit_elapsed),
        ("All time:", prep_elapsed + fit_elapsed),
    )
    for caption, seconds in report:
        print(caption, np.round(seconds, 4))

    if v:
        plot_training(classifier)
if __name__ == "__main__":
    # Command-line entry point: collect the options, then run training.
    cli = ArgumentParser()
    cli.add_argument(
        'data_path',
        type=str,
        help='Path to "dataset_train.csv" file',
    )
    cli.add_argument(
        '--weights_path',
        type=str,
        default="data/weights.pt",
        help='Path to save weights file',
    )
    cli.add_argument(
        '--config_path',
        type=str,
        default="config.yaml",
        help='path to .yaml file',
    )
    cli.add_argument(
        '-v',
        action="store_true",
        help='visualize training',
    )
    options = cli.parse_args()

    train(
        data_path=options.data_path,
        weights_path=options.weights_path,
        config_path=options.config_path,
        v=options.v,
    )