-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
54 lines (37 loc) · 1.47 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from data_preprocessor import DataHandler
from models import MainModel
import matplotlib.pyplot as plt
import torch, time
INIT_TIME = time.time()  # wall-clock timestamp taken when this module is imported


def _plot_history(history):
    """Plot the loss curves and testing accuracy recorded by a model.

    Args:
        history: a 3-item sequence ``(train_history, test_history, accuracy)``
            taken from ``MainModel.history``. Each item is plotted against its
            index; presumably one entry per epoch — TODO confirm against
            ``models.MainModel``.
    """
    train_history, test_history, accuracy = history
    fig, axs = plt.subplots(2)
    fig.set_size_inches(12, 9.5)

    # Top panel: training vs. testing loss.
    axs[0].set_title('Total Loss')
    axs[0].plot(train_history, label='training loss')
    axs[0].plot(test_history, label='testing loss')
    axs[0].set_xlabel('Iterations Number')
    axs[0].set_ylabel('Loss per Epoch')
    axs[0].legend()

    # Bottom panel: testing accuracy.
    axs[1].set_title('Total Testing Accuracy')
    axs[1].plot(accuracy)
    axs[1].set_xlabel('Iterations Number')
    axs[1].set_ylabel('Accuracy per Epoch')

    plt.show()


def main(*, train=False, show_plots=False):
    """Build the data splits and the CNN model, optionally train it, then evaluate it.

    Args:
        train: when True, call ``model.train_loop`` before evaluation and
            print how long training took. Off by default; this stage was
            previously commented out.
        show_plots: when True, plot ``model.history`` after evaluation.
            Off by default; this stage was previously commented out.
    """
    # Fix torch's RNG seed so repeated runs behave the same.
    torch.manual_seed(100)

    # Build the data handler over ./data and get the train/test splits.
    img_size = (80, 80)
    img_handler = DataHandler('./data', size=img_size)
    class_names = img_handler.class_names
    train_data, test_data = img_handler.get_data(batch_size=20, test_size=0.08)

    # Create the model; load_path presumably points at previously saved
    # weights — confirm against models.MainModel.
    model_path = './model_data/cnn_model2.pth'
    model = MainModel(len(class_names), learning_rate=0.05, size=img_size, load_path=model_path)

    if train:
        # Time only the training call. Measuring from INIT_TIME would also
        # count imports, data loading and model construction.
        train_start = time.perf_counter()
        # NOTE(review): the meaning of the first argument (0) is defined by
        # MainModel.train_loop — confirm (epoch count?).
        model.train_loop(0, train_data, test_data)
        print(f'Training Time: {time.perf_counter() - train_start:.2f}s')

    acc, loss = model.test_loop(test_data)
    # acc is multiplied by 100 for display; presumably a fraction in [0, 1].
    print(f'Accuracy: [{acc*100:.2f}%] | Test Loss: [{loss:.4f}]')

    if show_plots:
        _plot_history(model.history)


if __name__ == '__main__':
    main()