main.py
"""
Class used to run local experiments using GANs and different Netflow datasets.
"""
import os

# Set before TensorFlow is imported so the GPU selection and log level take
# effect: use only the first GPU and silence TensorFlow INFO/WARNING messages.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"

import json
import argparse
from pathlib import Path

import tensorflow as tf

from anomaly_flow.data.netflow import NetFlowV2
from anomaly_flow.train.trainer_flow_nids import GANomaly
parser = argparse.ArgumentParser(description='Anomaly Flow - NetFlow Experiments.')
parser.add_argument('--train-dataset', type=str, help="Dataset to train the models.", required=True)
# Note: --train_size is parsed but not referenced elsewhere in this script.
parser.add_argument("--train_size", type=int, required=False)
parser.add_argument(
    '--test-datasets', type=str,
    nargs='+', default=[],
    help="List of datasets to test the trained model."
)
args = parser.parse_args()
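
# Example invocation (dataset names are illustrative; use whichever
# identifiers the NetFlowV2 loader supports in your environment):
#   python main.py --train-dataset NF-CSE-CIC-IDS2018-v2 \
#       --test-datasets NF-ToN-IoT-v2 NF-BoT-IoT-v2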
# Load the shared hyperparameters from disk.
with open('hps.json', 'r', encoding='utf-8') as file:
    hps = json.load(file)
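
# A minimal sketch of the expected hps.json shape, based only on the keys read
# in this script (values are placeholders, not tuned recommendations; the
# GANomaly trainer may require additional entries):
# {
#     "batch_size": 256,
#     "shuffle_buffer_size": 10000,
#     "epochs": 10,
#     "adversarial_loss_weight": 1,
#     "contextual_loss_weight": 50,
#     "enc_loss_weight": 1,
#     "step_log_frequency": 100
# }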
if __name__ == "__main__":
    # Build and configure the training dataset.
    netflow_dataset = NetFlowV2(args.train_dataset)
    netflow_dataset.configure(
        hps["batch_size"], 52, 1,
        hps["shuffle_buffer_size"], True, True
    )

    # Instantiate the GANomaly trainer with a TensorBoard summary writer.
    netflow_trainer = GANomaly(
        netflow_dataset, hps, tf.summary.create_file_writer("logs"), Path("log")
    )

    # Train using the loss weights and logging frequency defined in hps.json.
    netflow_trainer.train(
        hps["epochs"], hps["adversarial_loss_weight"], hps["contextual_loss_weight"],
        hps["enc_loss_weight"], hps["step_log_frequency"]
    )

    # Evaluate on the training dataset's own test split.
    netflow_trainer.test()

    # Cross-evaluate the trained model on each additional test dataset.
    for dataset in args.test_datasets:
        cross_eval_dataset = NetFlowV2(dataset)
        cross_eval_dataset.configure(
            hps["batch_size"], 52, 1,
            hps["shuffle_buffer_size"], True, True
        )
        netflow_trainer.test(
            test_dataset=cross_eval_dataset.get_test_dataset(),
            experiment_name=f"Cross_Evaluation_{args.train_dataset}-{dataset}"
        )
        del cross_eval_dataset