make_train_val.py
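"""Split a labelset of per-class image folders into train and val sets.

Given a dataset laid out as one subdirectory per class, build a timestamped
output directory containing train and val subtrees (names configurable via
--train_name/--val_name) with the same class structure, symlinking or
copying a sampled subset of images into each.

Example invocation (paths are illustrative):

    python make_train_val.py /path/to/labels /path/to/output --train_size 800 --val_size 200
"""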
import os
import glob
import random
import shutil
import argparse
from datetime import datetime

import numpy as np


def list_image_counts(data_path):
    """Print the number of images found in each class directory."""
    img_dirs = glob.glob(os.path.join(data_path, '*'))
    for d in img_dirs:
        imgs = glob.glob(os.path.join(d, '*'))
        print(str(len(imgs)).zfill(5) + ' --- ' + os.path.basename(d))
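
# Usage sketch for the helper above (it is not called anywhere in this
# script; the path is illustrative):
#   list_image_counts('/path/to/labels')
# prints one zero-padded count per class, e.g. "00342 --- class_name"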


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Make a train and val set from a labelset')
    parser.add_argument('data_dir', metavar='data_dir', help='path to labelset')
    parser.add_argument('output_dir', metavar='output_dir', help='path to the output train/val set')
    parser.add_argument('--train_name', default='train', help='name of training dir to create')
    parser.add_argument('--val_name', default='val', help='name of validation dir to create')
    parser.add_argument('--train_pct', type=float, default=0,
                        help='fraction of each class to use for training; 0 uses --train_size/--val_size instead')
    parser.add_argument('--train_size', type=int, default=800, help='number of images in train set per class')
    parser.add_argument('--val_size', type=int, default=200, help='number of images in val set per class')
    parser.add_argument('--min_images', type=int, default=100, help='smallest number of images per class to use')
    # duplication is on by default; passing --duplicate disables it
    # (action='store_false' stores False when the flag is given)
    parser.add_argument('--duplicate', action='store_false',
                        help='pass to disable duplicating images in classes smaller than train_size + val_size')
    # symlinking is on by default; --no-symlink copies instead
    # (argparse.BooleanOptionalAction requires Python 3.9+)
    parser.add_argument('--symlink', action=argparse.BooleanOptionalAction, default=True,
                        help='symlink images instead of copying them into the new dir')
    args = parser.parse_args()

    img_path = args.data_dir
    dataset_parent = args.output_dir
    train_name = args.train_name
    val_name = args.val_name
    train_pct = args.train_pct
    train_size = args.train_size
    val_size = args.val_size
    min_imgs = args.min_images
    duplicate = args.duplicate
    symflag = args.symlink

    # timestamped output dir; drop microseconds and replace colons, which
    # are not filesystem-safe on all platforms
    dataset_path = os.path.join(dataset_parent,
                                datetime.utcnow().isoformat()[:-7].replace(':', '-'))
    os.makedirs(os.path.join(dataset_path, train_name), exist_ok=True)
    os.makedirs(os.path.join(dataset_path, val_name), exist_ok=True)

    img_dirs = sorted(glob.glob(os.path.join(img_path, '*')))
    for img_dir in img_dirs:
        print(img_dir)
        imgs = sorted(glob.glob(os.path.join(img_dir, '*')))
        if train_pct != 0:
            # percentage split: shuffle all indices and cut at train_pct
            all_inds = random.sample(range(len(imgs)), len(imgs))
            train_inds = all_inds[:int(train_pct * len(imgs))]
            val_inds = all_inds[int(train_pct * len(imgs)):]
        else:
            if len(imgs) < min_imgs:
                print('not enough images, skipping class: ' + os.path.basename(img_dir))
                continue
            # fixed-size split: check whether the class holds fewer images
            # than the total needed for train + val
            if duplicate and len(imgs) < (train_size + val_size):
                if (train_size + val_size) - len(imgs) < len(imgs):
                    # the shortfall is smaller than the class itself, so the
                    # filler indices can be drawn without replacement
                    fill = np.random.choice(len(imgs), (train_size + val_size) - len(imgs), replace=False)
                else:
                    # otherwise draw the filler with replacement
                    fill = np.random.choice(len(imgs), (train_size + val_size) - len(imgs), replace=True)
                # append the filler to one full pass over the class, then shuffle
                all_inds = np.concatenate([np.arange(len(imgs)), fill])
                np.random.shuffle(all_inds)
            else:
                # enough images (or duplication disabled): sample without
                # replacement; note this raises if the class is too small
                all_inds = np.random.choice(len(imgs), train_size + val_size, replace=False)
            train_inds = all_inds[:train_size]
            val_inds = all_inds[train_size:]
            print('Number of unique train_inds: ' + str(len(set(train_inds))))

        # train split: link or copy the selected images, tagging each
        # destination with its index so duplicated images get unique names
        new_img_dir = os.path.join(dataset_path, train_name, os.path.basename(img_dir))
        os.makedirs(new_img_dir, exist_ok=True)
        for i, ind in enumerate(train_inds):
            base, ext = os.path.splitext(os.path.basename(imgs[ind]))
            img_dest = os.path.join(new_img_dir, base + '_index_' + str(i).zfill(4) + ext)
            if symflag:
                # link to the absolute source path so the symlink resolves
                # no matter where it is read from
                os.symlink(os.path.abspath(imgs[ind]), img_dest)
            else:
                shutil.copy(imgs[ind], img_dest)
        # val split: same naming and link/copy scheme as above
        new_img_dir = os.path.join(dataset_path, val_name, os.path.basename(img_dir))
        os.makedirs(new_img_dir, exist_ok=True)
        for i, ind in enumerate(val_inds):
            base, ext = os.path.splitext(os.path.basename(imgs[ind]))
            img_dest = os.path.join(new_img_dir, base + '_index_' + str(i).zfill(4) + ext)
            if symflag:
                os.symlink(os.path.abspath(imgs[ind]), img_dest)
            else:
                shutil.copy(imgs[ind], img_dest)