Skip to content

Commit 5379a25

Browse files
committed
Added function to crop detections and save them as new images
1 parent dc2b067 commit 5379a25

File tree

6 files changed

+62
-4
lines changed

6 files changed

+62
-4
lines changed

core/functions.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,25 @@ def count_objects(data, by_class = False):
2828
else:
2929
counts['total object'] = num_objects
3030

31-
return counts
31+
return counts
32+
33+
# function for cropping each detection and saving as new image
34+
def crop_objects(img, data, path, allowed_classes = None):
35+
boxes, scores, classes, num_objects = data
36+
class_names = read_class_names(cfg.YOLO.CLASSES)
37+
#create dictionary to hold count of objects for image name
38+
counts = dict()
39+
for i in range(num_objects):
40+
# get count of class for part of image name
41+
class_index = int(classes[i])
42+
class_name = class_names[class_index]
43+
counts[class_name] = counts.get(class_name, 0) + 1
44+
# get box coords
45+
xmin, ymin, xmax, ymax = boxes[i]
46+
# crop detection from image (take an additional 5 pixels around all edges)
47+
cropped_img = img[int(ymin)-5:int(ymax)+5, int(xmin)-5:int(xmax)+5]
48+
# construct image name and join it to path for saving crop properly
49+
img_name = class_name + '_' + str(counts[class_name]) + '.png'
50+
img_path = os.path.join(path, img_name )
51+
# save image
52+
cv2.imwrite(img_path, cropped_img)

detect.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@
2626
flags.DEFINE_list('images', './data/images/kite.jpg', 'path to input image')
2727
flags.DEFINE_string('output', './detections/', 'path to output folder')
2828
flags.DEFINE_float('iou', 0.45, 'iou threshold')
29-
flags.DEFINE_float('score', 0.25, 'score threshold')
29+
flags.DEFINE_float('score', 0.50, 'score threshold')
3030
flags.DEFINE_boolean('count', False, 'count objects within images')
3131
flags.DEFINE_boolean('dont_show', False, 'dont show image output')
3232
flags.DEFINE_boolean('info', False, 'print info on detections')
33+
flags.DEFINE_boolean('crop', False, 'crop detections from images')
3334

3435
def main(_argv):
3536
config = ConfigProto()
@@ -52,6 +53,10 @@ def main(_argv):
5253

5354
image_data = cv2.resize(original_image, (input_size, input_size))
5455
image_data = image_data / 255.
56+
57+
# get image name by using split method
58+
image_name = image_path.split('/')[-1]
59+
image_name = image_name.split('.')[0]
5560

5661
images_data = []
5762
for i in range(1):
@@ -95,6 +100,15 @@ def main(_argv):
95100
# hold all detection data in one variable
96101
pred_bbox = [bboxes, scores.numpy()[0], classes.numpy()[0], valid_detections.numpy()[0]]
97102

103+
# if crop flag is enabled, crop each detection and save it as new image
104+
if FLAGS.crop:
105+
crop_path = os.path.join(os.getcwd(), 'detections', 'crop', image_name)
106+
try:
107+
os.mkdir(crop_path)
108+
except FileExistsError:
109+
pass
110+
crop_objects(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB), pred_bbox, crop_path)
111+
98112
if FLAGS.count:
99113
# count objects found
100114
counted_classes = count_objects(pred_bbox, by_class = False)

detect_video.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from absl.flags import FLAGS
1111
import core.utils as utils
1212
from core.yolov4 import filter_boxes
13-
from core.functions import count_objects
13+
from core.functions import *
1414
from tensorflow.python.saved_model import tag_constants
1515
from PIL import Image
1616
import cv2
@@ -32,6 +32,7 @@
3232
flags.DEFINE_boolean('count', False, 'count objects within video')
3333
flags.DEFINE_boolean('dont_show', False, 'dont show video output')
3434
flags.DEFINE_boolean('info', False, 'print info on detections')
35+
flags.DEFINE_boolean('crop', False, 'crop detections from images')
3536

3637
def main(_argv):
3738
config = ConfigProto()
@@ -40,7 +41,9 @@ def main(_argv):
4041
STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
4142
input_size = FLAGS.size
4243
video_path = FLAGS.video
43-
44+
# get video name by using split method
45+
video_name = video_path.split('/')[-1]
46+
video_name = video_name.split('.')[0]
4447
if FLAGS.framework == 'tflite':
4548
interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
4649
interpreter.allocate_tensors()
@@ -68,10 +71,12 @@ def main(_argv):
6871
codec = cv2.VideoWriter_fourcc(*FLAGS.output_format)
6972
out = cv2.VideoWriter(FLAGS.output, codec, fps, (width, height))
7073

74+
frame_num = 0
7175
while True:
7276
return_value, frame = vid.read()
7377
if return_value:
7478
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
79+
frame_num += 1
7580
image = Image.fromarray(frame)
7681
else:
7782
print('Video has ended or failed, try a different video format!')
@@ -116,6 +121,24 @@ def main(_argv):
116121

117122
pred_bbox = [bboxes, scores.numpy()[0], classes.numpy()[0], valid_detections.numpy()[0]]
118123

124+
# if crop flag is enabled, crop each detection and save it as new image
125+
if FLAGS.crop:
126+
crop_rate = 150 # capture images every so many frames (ex. crop photos every 150 frames)
127+
crop_path = os.path.join(os.getcwd(), 'detections', 'crop', video_name)
128+
try:
129+
os.mkdir(crop_path)
130+
except FileExistsError:
131+
pass
132+
if frame_num % crop_rate == 0:
133+
final_path = os.path.join(crop_path, 'frame_' + str(frame_num))
134+
try:
135+
os.mkdir(final_path)
136+
except FileExistsError:
137+
pass
138+
crop_objects(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), pred_bbox, final_path)
139+
else:
140+
pass
141+
119142
if FLAGS.count:
120143
# count objects found
121144
counted_classes = count_objects(pred_bbox, by_class = False)

detections/crop/dog/bicycle_1.png

245 KB
Loading

detections/crop/dog/dog_1.png

106 KB
Loading

detections/crop/dog/truck_1.png

48.5 KB
Loading

0 commit comments

Comments
 (0)