forked from hyeonthan/bcic4-2a-classification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path training.py
69 lines (49 loc) · 1.65 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Import libraries"""
import os, glob, yaml
from datetime import datetime
import numpy as np
from easydict import EasyDict
import pytz
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from sklearn.metrics import *
from dataloaders.bcic4a import BCICompet2aIV, BCICompet2aIV_TEST
from model.deepconvnet import DeepConvNet
# from dataloaders.preprocessing import preprocessing_vhdr, prepare_label
from utils.train import train
""" Config setting"""
CONFIG_PATH = f"{os.getcwd()}/configs"
filename = "config.yaml"
# NOTE(review): yaml.FullLoader is acceptable for a trusted local config file;
# switch to yaml.safe_load if this file could ever come from an untrusted source.
with open(f"{CONFIG_PATH}/{filename}", encoding="utf-8") as file:
    config = yaml.load(file, Loader=yaml.FullLoader)
args = EasyDict(config)

# Set Device
# CUDA_VISIBLE_DEVICES is only honoured if it is exported before the CUDA
# runtime is initialised, and torch.cuda.is_available() can initialise it.
# The variable is therefore set first. os.environ only accepts str values,
# while YAML parses a bare device index such as `0` as an int.
if "GPU_NUM" in args:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_NUM)
if torch.cuda.is_available():
    cudnn.benchmark = True
    # NOTE(review): `fastest` is not a documented torch.backends.cudnn flag
    # (it comes from Lua Torch); this assignment likely has no effect - confirm.
    cudnn.fastest = True
    # This was previously misspelled `deteministic`, which just created an
    # unused attribute and never enabled deterministic cuDNN kernels.
    # PyTorch's reproducibility notes also recommend benchmark = False when
    # strict determinism is required.
    cudnn.deterministic = True

# PyYAML (YAML 1.1) reads values like `1e-3` (no dot in the mantissa) as
# strings, so the optimizer hyper-parameters are coerced to float explicitly.
args.lr = float(args.lr)
args.weight_decay = float(args.weight_decay)

# Set SEED
torch.manual_seed(args.SEED)
def main():
    """Build the model and the train/validation/test loaders, then run ``train``."""
    net = DeepConvNet().to(device=args.gpu)

    # ``args.train_mode`` is set immediately before each dataset constructor,
    # presumably so the dataset class can pick its split from it - confirm in
    # dataloaders.bcic4a. After this loop it is left at 'test'.
    datasets = []
    for mode, dataset_cls in (
        ("train", BCICompet2aIV),
        ("validation", BCICompet2aIV),
        ("test", BCICompet2aIV_TEST),
    ):
        args.train_mode = mode
        datasets.append(dataset_cls(args))
    train_set, valid_set, test_set = datasets

    # Only the training split is shuffled; evaluation splits go one sample at a time.
    loaders = (
        DataLoader(train_set, shuffle=True, batch_size=256),
        DataLoader(valid_set, shuffle=False, batch_size=1),
        DataLoader(test_set, shuffle=False, batch_size=1),
    )
    train(*loaders, net, args)
if __name__ == "__main__":
    import sys
    import traceback

    try:
        main()
    except Exception:
        # Top-level boundary: report the full traceback on stderr, then exit
        # non-zero so shells and job schedulers see the run as failed.
        # Previously the error was printed to stdout and the process exited 0.
        traceback.print_exc()
        sys.exit(1)