-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_eval.py
46 lines (37 loc) · 1.55 KB
/
test_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
import json
import os
import time
import torch as t
import torch.utils.data.dataloader as DataLoader
import torchvision as tv
from model import Mydataloader
from model.myNetwork import MyCNN
from model.func import load_model
def _load_config(path="config.json"):
    """Read and return the JSON configuration stored at *path*.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes instead of lingering until garbage collection.
    """
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _parse_args(config):
    """Parse the ``--gpu`` / ``--epoch`` command-line options.

    ``config["GPU"]`` supplies the default for ``--gpu``; ``--epoch`` defaults
    to 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", default=config["GPU"], type=str, help="choose which DEVICE U want to use")
    parser.add_argument("--epoch", default=0, type=int, help="The epoch to be tested")
    return parser.parse_args()


def main():
    """Run a saved MyCNN checkpoint over the test set and save image results.

    Each test sample is passed through the model; the input and the model's
    output are concatenated along dim 3 and written as a JPEG to
    ``result/test_result/epoch_<epoch>/idx_<batch_idx>.jpg``. The elapsed time
    is printed at the end.
    """
    time_start = time.time()
    config = _load_config()
    args = _parse_args(config)

    # argparse applies ``type=str`` only to *string* defaults, so a numeric
    # "GPU" entry in config.json would arrive here as an int and os.environ
    # would raise TypeError; coerce explicitly. Set before any CUDA work.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = t.device(config["DEVICE"])

    test_data = Mydataloader.TestingData()
    test_loader = DataLoader.DataLoader(
        test_data, batch_size=1, shuffle=False, num_workers=config["num_workers"]
    )

    model = MyCNN(n_channels=8).to(device)
    # load_model lives in model.func; presumably it restores the weights saved
    # for args.epoch — confirm there.
    model = load_model(model, args.epoch)
    model = model.eval()

    # The output directory does not depend on the sample, so create it once.
    out_dir = "result/test_result/epoch_{}".format(args.epoch)
    os.makedirs(out_dir, exist_ok=True)

    to_pil = tv.transforms.ToPILImage()
    with t.no_grad():
        for batch_idx, inputs in enumerate(test_loader):
            inputs = inputs.to(device)
            outputs = model(inputs)
            # Put input and prediction side by side. Assumes both tensors share
            # every dim except dim 3 (presumably width in NCHW) — TODO confirm.
            side_by_side = t.cat([inputs, outputs], dim=3)
            # batch_size=1: drop only the batch dim. A bare squeeze() would also
            # collapse any other size-1 dim (e.g. a 1-pixel-high image).
            image = to_pil(side_by_side.squeeze(0).cpu())
            # One file per sample so results are not overwritten each iteration.
            image.save("{}/idx_{}.jpg".format(out_dir, batch_idx))

    print("Evaluation finished in {:.2f}s".format(time.time() - time_start))


if __name__ == "__main__":
    main()