-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_tfrecords.py
225 lines (189 loc) · 7.63 KB
/
create_tfrecords.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import argparse
import os
import random
import sys
import threading
import numpy as np
import tensorflow as tf
from datetime import datetime
from queue import Queue
from glob import glob
def _bytes_feature(value):
    """Wrap a raw byte string in a tf.train.Feature.

    :param value: bytes object to store.
    :return: a tf.train.Feature holding value as a one-element bytes list.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _convert_to_example(
        image_buffer,
        mask_buffer,
        image_filename,
        mask_filename):
    """Build a tf.train.Example holding one encoded image/mask pair.

    :param image_buffer: encoded image bytes.
    :param mask_buffer: encoded mask bytes.
    :param image_filename: path of the image file; only its basename is stored.
    :param mask_filename: path of the mask file; only its basename is stored.
    :return: a tf.train.Example with 'image', 'mask', 'image_filename' and
        'mask_filename' features.
    """
    # Store only the basenames so the records are independent of local paths.
    image_filename = os.path.basename(image_filename)
    mask_filename = os.path.basename(mask_filename)
    example = tf.train.Example(features=tf.train.Features(feature={
        'image': _bytes_feature(image_buffer),
        'mask': _bytes_feature(mask_buffer),
        "image_filename": _bytes_feature(bytes(image_filename, encoding="UTF-8")),
        "mask_filename": _bytes_feature(bytes(mask_filename, encoding="UTF-8"))
    }))
    return example
def _process_image_files_batch(
        thread_index,
        batch_data,
        shards,
        total_shards,
        mask_dir,
        mask_suffix,
        output_dir,
        output_name,
        error_queue):
    """Write one thread's batch of image/mask pairs into its TFRecord shards.

    :param thread_index: 1-based index of this worker thread (for logging).
    :param batch_data: list of image file paths assigned to this thread.
    :param shards: 1-based global shard numbers this thread writes.
    :param total_shards: total number of shards across all threads.
    :param mask_dir: directory containing the mask files.
    :param mask_suffix: suffix appended to the image basename to find its mask.
    :param output_dir: directory the TFRecord files are written to.
    :param output_name: prefix of the TFRecord file names.
    :param error_queue: queue collecting error messages across all threads.
    :return: None
    """
    batch_size = len(batch_data)
    batch_per_shard = batch_size // len(shards)
    counter = 0
    error_counter = 0
    for s in range(len(shards)):
        shard = shards[s]
        output_filename = '%s-%.5d-of-%.5d' % (output_name, shard, total_shards)
        output_file = os.path.join(output_dir, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)
        shard_counter = 0
        # Fix: the shard range is a [start, end) slice of batch_data; the
        # original used (start, count), which left every shard after the
        # first empty. The last shard absorbs the division remainder.
        start = s * batch_per_shard
        end = batch_size if s == len(shards) - 1 else (s + 1) * batch_per_shard
        try:
            for i in range(start, end):
                # Fix: index this thread's batch, not the module-level `data`
                # list — the old code made every thread re-read the same files.
                image_filename = batch_data[i]
                mask_filename = os.path.join(
                    mask_dir,
                    os.path.splitext(os.path.split(image_filename)[-1])[0] + mask_suffix)
                try:
                    image_buffer = tf.gfile.FastGFile(image_filename, 'rb').read()
                    if not os.path.exists(mask_filename):
                        # Fall back to a plain ".jpg" mask name; skip the pair
                        # entirely if no mask exists under either name.
                        mask_filename = os.path.join(
                            mask_dir,
                            os.path.splitext(os.path.split(image_filename)[-1])[0] + ".jpg")
                        if not os.path.exists(mask_filename):
                            continue
                    mask_buffer = tf.gfile.FastGFile(mask_filename, 'rb').read()
                    example = _convert_to_example(image_buffer, mask_buffer, image_filename, mask_filename)
                    writer.write(example.SerializeToString())
                    shard_counter += 1
                    counter += 1
                except Exception as e:
                    # Fix: the original caught only StopIteration, which none
                    # of these calls raise, so real I/O errors killed the
                    # thread instead of being reported through error_queue.
                    error_counter += 1
                    error_queue.put(repr(e))
        finally:
            # Fix: close the writer so records are flushed even on error.
            writer.close()
        print('%s [thread %d]: Wrote %d images to %s, with %d errors.' %
              (datetime.now(), thread_index, shard_counter, output_file, error_counter))
        sys.stdout.flush()
    print('%s [thread %d]: Wrote %d images to %d shards, with %d errors.' %
          (datetime.now(), thread_index, counter, len(shards), error_counter))
    sys.stdout.flush()
def create(
        data,
        mask_dir,
        mask_suffix,
        output_name,
        output_dir,
        num_shards,
        num_threads):
    """Split `data` across worker threads and write TFRecord shards.

    :param data: list of image file paths to convert.
    :param mask_dir: directory containing the mask files.
    :param mask_suffix: suffix appended to the image basename to find its mask.
    :param output_name: prefix of the TFRecord file names (see NOTE below).
    :param output_dir: directory the TFRecord files are written to (see NOTE).
    :param num_shards: total number of output shards (divisible by num_threads).
    :param num_threads: number of worker threads.
    :return: None
    """
    num_data_per_thread = len(data) // num_threads
    num_shard_per_thread = num_shards // num_threads
    # Fix: ranges are [start, end) slices; the original stored a count as the
    # second element, so every thread after the first sliced an empty batch.
    # The last thread absorbs the remainder so no sample is silently dropped.
    batch_data_ranges = []
    for i in range(num_threads):
        start = i * num_data_per_thread
        end = len(data) if i == num_threads - 1 else (i + 1) * num_data_per_thread
        batch_data_ranges.append((start, end))
    coord = tf.train.Coordinator()
    error_queue = Queue()
    threads = []
    for thread_index in range(1, num_threads + 1):
        batch_ranges = batch_data_ranges[thread_index - 1]
        batch_data = data[batch_ranges[0]:batch_ranges[1]]
        # 1-based global shard numbers for this thread (algebraically equal to
        # the original thread_index + (thread_index-1)*(n-1) + shard form).
        shards = [(thread_index - 1) * num_shard_per_thread + shard + 1
                  for shard in range(num_shard_per_thread)]
        # NOTE(review): output_name and output_dir are passed here in the
        # reverse order of _process_image_files_batch's signature; the calls
        # in __main__ invert them again, so the two inversions cancel at
        # runtime. Left as-is to preserve behavior — fix both sites together.
        args = (
            thread_index,
            batch_data,
            shards,
            num_shards,
            mask_dir,
            mask_suffix,
            output_name,
            output_dir,
            error_queue)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)
    coord.join(threads)
    errors = []
    while not error_queue.empty():
        errors.append(error_queue.get())
    print('%d examples failed.' % (len(errors),))
def parse_args(argv=None):
    """Parse command-line arguments for the TFRecord creation script.

    :param argv: optional list of argument strings; defaults to sys.argv[1:].
    :return: the parsed argparse.Namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', help='Directory of the images.', type=str, required=True)
    parser.add_argument('--mask_dir', help='Directory of the masks.', type=str, required=True)
    parser.add_argument('--image_suffix', help='Suffix of the images.', type=str, default=".jpg")
    parser.add_argument('--mask_suffix', help='Suffix of the masks.', type=str, default="_mask.jpg")
    parser.add_argument('--train_size', help='Ratio of training samples.', type=float, required=True)
    parser.add_argument('--validation_size', help='Ratio of validation samples.', type=float, required=True)
    parser.add_argument('--output_dir', help='Directory for the tfrecords.', type=str, default="./")
    parser.add_argument('--shards', help='Number of shards to make.', type=int, default=1)
    parser.add_argument('--threads', help='Number of threads to use.', type=int, default=1)
    # Fix: '--shuffle' combined action='store_true' with default=True, so the
    # flag could never change anything. Shuffling stays on by default
    # (backward compatible); '--no_shuffle' now actually disables it.
    parser.add_argument('--shuffle', dest='shuffle', help='Shuffle the samples.', action='store_true', default=True)
    parser.add_argument('--no_shuffle', dest='shuffle', help='Do not shuffle the samples.', action='store_false')
    parsed_args = parser.parse_args(argv)
    return parsed_args
def assert_args(args):
    """Validate parsed command-line arguments.

    :param args: namespace from parse_args with image_dir, mask_dir,
        train_size, validation_size, shards and threads attributes.
    :raises AssertionError: if any argument is invalid.
    :return: None
    """
    assert os.path.exists(args.image_dir), "Images directory does not exist"
    assert os.path.exists(args.mask_dir), "Mask directory does not exist"
    assert args.train_size + args.validation_size <= 1, "Train ratio + validation ratio must be <= 1"
    assert args.train_size > 0, "Train ratio must be > 0"
    assert args.validation_size >= 0, "Validation ratio must be >= 0"
    assert args.shards > 0, "Number of shards must be > 0"
    assert args.threads > 0, "Number of threads must be > 0"
    # Fix: this check lacked a message, unlike every other assertion here.
    assert args.shards % args.threads == 0, "Number of shards must be divisible by number of threads"
    # NOTE(review): the image/mask count cross-check is intentionally disabled;
    # masks may use the fallback ".jpg" suffix, so the counts need not match:
    # assert len(glob(args.image_dir + "/*" + args.image_suffix)) == \
    #     len(glob(args.mask_dir + "/*" + args.mask_suffix)), "Number of images and masks does not match"
if __name__ == '__main__':
    args = parse_args()
    assert_args(args)
    # Collect every image path matching the configured suffix.
    data = glob(os.path.join(args.image_dir, "*" + args.image_suffix))
    if args.shuffle:
        random.shuffle(data)
    num_data = len(data)
    # Split sizes are ratios of the full dataset, rounded to whole samples.
    num_train = round(num_data * args.train_size)
    num_validation = round(num_data * args.validation_size)
    training = data[:num_train]
    # NOTE(review): the 4th/5th positional arguments look swapped relative to
    # create(data, mask_dir, mask_suffix, output_name, output_dir, ...), but
    # create() inverts them again when building the worker args tuple, so the
    # two inversions cancel — change both call sites together or not at all.
    create(training, args.mask_dir, args.mask_suffix, args.output_dir, "train", args.shards, args.threads)
    if args.validation_size > 0:
        validation = data[num_train:num_train + num_validation]
        create(validation, args.mask_dir, args.mask_suffix, args.output_dir, "val", args.shards, args.threads)
    # Anything left after train + validation becomes the test split.
    if args.train_size + args.validation_size < 1:
        test = data[num_train + num_validation:]
        create(test, args.mask_dir, args.mask_suffix, args.output_dir, "test", args.shards, args.threads)