This repository has been archived by the owner on May 10, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest_mem_model.py
64 lines (52 loc) · 2.36 KB
/
test_mem_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
from utils import get_logger, make_date_dir, find_latest_dir
from data_utils import load_agg_selected_data_mem, batch_loader
import numpy as np
from time import time
from AR_mem.config import Config
from AR_mem.model import Model
def _resolve_model_dir(config):
    """Return the checkpoint directory that ``main`` restores the model from.

    When ``config.latest_model`` is truthy this is whatever ``find_latest_dir``
    returns for ``<config.model>/model_save/``; otherwise it is
    ``config.model_dir``.

    Raises:
        ValueError: if ``config.latest_model`` is falsy and
            ``config.model_dir`` is falsy as well.
    """
    if config.latest_model:
        return find_latest_dir(os.path.join(config.model, 'model_save/'))
    # Test the configured value itself: checking a local that has not been
    # assigned yet raises UnboundLocalError instead of validating the config.
    if not config.model_dir:
        raise ValueError("model_dir or latest_model=True should be defined in config")
    return config.model_dir


def _predict(model, config, test_x, test_m, test_y, batch_threshold=100000):
    """Run ``model.eval`` over the test set and return the predictions.

    Args:
        model: restored model; ``model.eval(x, m, y)`` must return a 5-tuple
            whose first element is the predictions.
        config: supplies ``batch_size`` for batched evaluation.
        test_x, test_m, test_y: test inputs, the extra ``m`` input
            (presumably the memory component of AR_mem — confirm) and targets.
        batch_threshold: when ``len(test_y)`` exceeds this, evaluate in
            mini-batches of ``config.batch_size`` instead of a single call.

    Returns:
        The predictions. In batched mode, the per-batch predictions are
        stacked along axis 0 on top of an empty ``(0, test_y.shape[1])``
        float array.
    """
    if len(test_y) <= batch_threshold:
        # Small enough to evaluate in one call; metrics are not used here.
        pred, _, _, _, _ = model.eval(test_x, test_m, test_y)
        return pred

    test_data = list(zip(test_x, test_m, test_y))
    # Seed with an empty float array so the result has the same shape/dtype
    # the original np.r_ accumulation produced, even if no batch is yielded.
    chunks = [np.empty(shape=(0, test_y.shape[1]))]
    for batch in batch_loader(test_data, config.batch_size):
        batch_x, batch_m, batch_y = zip(*batch)
        pred, _, _, _, _ = model.eval(batch_x, batch_m, batch_y)
        chunks.append(pred)
    # One concatenate at the end instead of np.r_ per batch, which re-copied
    # the whole accumulated array every iteration (quadratic overall).
    return np.concatenate(chunks, axis=0)


def main():
    """Evaluate a saved AR_mem model on the test split and save the results.

    Loads the test split with ``load_agg_selected_data_mem``, restores the
    model checkpoint, predicts, and writes ``pred.npy``, ``test_y.npy`` and
    ``test_dt.npy`` into a new directory created by ``make_date_dir`` under
    ``<config.model>/results/``. Any exception raised after the logger is set
    up is logged with its traceback rather than propagated.
    """
    config = Config()
    logger, _ = get_logger(os.path.join(config.model, "logs/"))
    logger.info("=======Model Configuration=======")
    logger.info(config.desc)
    logger.info("=================================")
    try:
        # Only the test slots of the returned 10-tuple are needed here.
        _, _, test_x, _, _, test_y, _, _, test_m, test_dt = load_agg_selected_data_mem(
            data_path=config.data_path,
            x_len=config.x_len,
            y_len=config.y_len,
            foresight=config.foresight,
            cell_ids=config.test_cell_ids,
            dev_ratio=config.dev_ratio,
            test_len=config.test_len,
            seed=config.seed,
        )
        model = Model(config)
        model.restore_session(_resolve_model_dir(config))
        total_pred = _predict(model, config, test_x, test_m, test_y)

        result_dir = make_date_dir(os.path.join(config.model, 'results/'))
        np.save(os.path.join(result_dir, 'pred.npy'), total_pred)
        np.save(os.path.join(result_dir, 'test_y.npy'), test_y)
        np.save(os.path.join(result_dir, 'test_dt.npy'), test_dt)
        logger.info("Saving results at {}".format(result_dir))
        logger.info("Testing finished, exit program")
    except Exception:
        # Top-level boundary: keep the traceback in the run log. Catching
        # Exception (not a bare except) still lets KeyboardInterrupt and
        # SystemExit stop the program.
        logger.exception("ERROR")
if __name__ == '__main__':
    # Script entry point: run the evaluation only when executed directly,
    # not when this module is imported.
    main()