-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_callback.py
31 lines (21 loc) · 1009 Bytes
/
predict_callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import sys
import keras.callbacks
import eval
class PredictCallback(keras.callbacks.Callback):
    """Keras callback that saves visualized sample predictions after each epoch.

    At the end of every epoch the model predicts on the sample cases returned
    by ``reader.get_sample_cases()`` and one JPEG per sample is written to
    ``<config.results_path>/<name>/epoch<epoch>/<i>.jpg``.
    """

    def __init__(self, reader, config, name):
        """Store the sample source and the per-run output path prefix.

        Args:
            reader: Object exposing ``get_sample_cases()``, which must return a
                ``(originals, processed, targets)`` triple.
            config: Object with a ``results_path`` attribute (root output dir).
            name: Sub-directory name under ``results_path`` for this run.
        """
        super().__init__()
        # This is a prefix, not a directory: the epoch number is appended
        # directly in on_epoch_end, giving e.g. ".../<name>/epoch3".
        self.path = os.path.join(config.results_path, name, "epoch")
        # Kept under its historical attribute name so external code that reads
        # ``callback.generator`` keeps working.
        self.generator = reader

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the sample cases and save one visualization per sample.

        Args:
            epoch: Epoch index passed in by Keras; used in the directory name.
            logs: Metrics dict passed in by Keras; not used here.
        """
        originals, processed, targets = self.generator.get_sample_cases()
        predictions = self.model.predict(processed)

        output_dir = self.path + str(epoch)
        # exist_ok=True avoids the check-then-create race of
        # os.path.exists() followed by os.makedirs().
        os.makedirs(output_dir, exist_ok=True)

        for i, original in enumerate(originals):
            # NOTE(review): eval.visualize appears to return an image object
            # with a PIL-style save(path, format) method — confirm in eval.py.
            ex = eval.visualize(original, predictions[i], targets[i])
            ex.save(os.path.join(output_dir, str(i) + ".jpg"), "JPEG")
            # Optional CRF post-processing, currently disabled:
            # crf_predictions = process_crf(processed[i, :, :, :], predictions[i, :, :, :])  # (224, 224, 4)
            # crf_ex = eval.visualize(original, crf_predictions, targets[i])
            # crf_ex.save(os.path.join(output_dir, str(i) + "_crf.jpg"), "JPEG")