-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
64 lines (51 loc) · 2.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import torch
from torch.utils.data import DataLoader
from config import Config
from component.ip_profiling import Profiler
from component.feature import FeatureExtractor
from component import rfre
from ml.dataset import RFREDataset
from ml.train import trainer, tester
from experiments.stats import save_stats
from utils import TQDM
def train(config: Config):
    """Run the training phase of the pipeline.

    Profiles every directory in ``config.train_dir_list`` (with
    ``benign_only=True``), extracts features from each resulting profile,
    RFRE-encodes the feature matrix, and hands a shuffled ``DataLoader``
    over the encoded data to ``trainer``.

    Args:
        config: Pipeline configuration; ``train_dir_list`` and
            ``batch_size`` are read here, and the object is also passed
            to ``Profiler``, ``rfre.encode`` and ``trainer``.
    """
    banner = "############################"
    print(banner)
    print("# start train phase #")
    print(banner)

    ip_profiler = Profiler(config)
    extractor = FeatureExtractor()

    # Training profiles are built with benign_only=True.
    for directory in config.train_dir_list:
        ip_profiler.profile(directory, benign_only=True)

    # Iterating the profiler yields the profiles accumulated above.
    for ip_profile in TQDM(ip_profiler, desc='extract features from profiles'):
        extractor.extract(ip_profile)

    encoded = rfre.encode(extractor.feature_matrix, extractor.feature_list, config)
    loader = DataLoader(
        RFREDataset(encoded, extractor.profile_key_list),
        batch_size=config.batch_size,
        shuffle=True,
    )
    trainer(config, loader)
def test(config: Config):
    """Run the testing phase: score each test directory with the saved model.

    Loads the model stored at ``config.path.model_path`` and puts it in eval
    mode. Then, for every directory in ``config.test_dir_list``, it:
    profiles the directory, extracts and RFRE-encodes features, runs
    ``tester`` over an unshuffled ``DataLoader``, and writes the returned
    lines to ``<config.path.result_dir>/<dir name>.csv``.

    Args:
        config: Pipeline configuration; ``path.model_path``,
            ``path.result_dir``, ``test_dir_list`` and ``batch_size`` are
            read here.
    """
    import inspect

    print("############################")
    print("# start test phase #######")
    print("############################")

    # The saved file holds a whole model object (we call .eval() on it), not
    # a state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which refuses to unpickle arbitrary modules, so opt out explicitly when
    # the installed torch supports the flag. NOTE(review): this unpickles the
    # file; only load model files produced by this project's own trainer.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    with open(config.path.model_path, 'rb') as f:
        model = torch.load(f, **load_kwargs).eval()

    for test_dir in config.test_dir_list:
        print(f"\nCurrent testing -> \"{test_dir}\"")
        profiler = Profiler(config)
        profiler.profile(test_dir)
        fe = FeatureExtractor()
        for profile in TQDM(profiler, desc='extract features from profiles'):
            fe.extract(profile)
        rfre_feature_matrix = rfre.encode(fe.feature_matrix, fe.feature_list, config)
        dataset = RFREDataset(rfre_feature_matrix, fe.profile_key_list)
        dataloader = DataLoader(dataset, batch_size=config.batch_size)
        result_csv = tester(model, dataloader)

        # normpath strips a trailing separator, so "data/day1/" yields
        # "day1" instead of "" (which would have produced a bare ".csv").
        dataset_name = os.path.basename(os.path.normpath(test_dir))
        result_csv_path = os.path.join(config.path.result_dir, dataset_name + ".csv")
        with open(result_csv_path, 'w') as f:
            f.writelines(line + "\n" for line in result_csv)
if __name__ == '__main__':
    # Full pipeline on the CICIDS2017 configuration: train, then test,
    # then save_stats, all sharing one Config instance.
    _config = Config('CICIDS2017')
    for _phase in (train, test, save_stats):
        _phase(_config)