Skip to content

Commit dc6c341

Browse files
Merge pull request #10197 from PurdueDualityLab:detection_generator_pr_2
PiperOrigin-RevId: 395505920
2 parents a9c5469 + c4a9fa6 commit dc6c341

File tree

4 files changed

+334
-6
lines changed

4 files changed

+334
-6
lines changed

official/vision/beta/projects/yolo/modeling/backbones/darknet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# Lint as: python3
1616
"""Contains definitions of Darknet Backbone Networks.
1717
18-
The models are inspired by ResNet, and CSPNet
18+
The models are inspired by ResNet and CSPNet.
1919
2020
Residual networks (ResNets) were proposed in:
2121
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Contains common building blocks for yolo layer (detection layer)."""
16+
import tensorflow as tf
17+
18+
from official.vision.beta.projects.yolo.ops import box_ops
19+
20+
21+
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloLayer(tf.keras.Model):
  """Yolo layer (detection generator).

  Splits the raw output of every detection head level into boxes, objectness
  scores and class scores, merges all levels, and runs combined non-max
  suppression over the result.
  """

  def __init__(self,
               masks,
               anchors,
               classes,
               iou_thresh=0.0,
               ignore_thresh=0.7,
               truth_thresh=1.0,
               nms_thresh=0.6,
               max_delta=10.0,
               loss_type='ciou',
               iou_normalizer=1.0,
               cls_normalizer=1.0,
               obj_normalizer=1.0,
               use_scaled_loss=False,
               darknet=None,
               pre_nms_points=5000,
               label_smoothing=0.0,
               max_boxes=200,
               new_cords=False,
               path_scale=None,
               scale_xy=None,
               nms_type='greedy',
               objectness_smooth=False,
               **kwargs):
    """Parameters for the loss functions used at each detection head output.

    Args:
      masks: `Dict[str, List[int]]` mapping each output level (as a string
        key, e.g. '3') to the indices into `anchors` used at that level.
      anchors: `List[List[int]]` for the anchor boxes that are used in the
        model.
      classes: `int` for the number of classes.
      iou_thresh: `float` to use many anchors per object if IoU(Obj, Anchor) >
        iou_thresh. In this layer it is also the score threshold applied to
        the objectness and class scores and passed to NMS.
      ignore_thresh: `float` for the IOU value over which the loss is not
        propagated, and a detection is assumed to have been made.
      truth_thresh: `float` for the IOU value over which the loss is propagated
        despite a detection being made.
      nms_thresh: `float` IOU threshold passed to non-max suppression.
      max_delta: gradient clipping to apply to the box loss.
      loss_type: `str` for the type of iou loss to use, within {ciou, diou,
        giou, iou}.
      iou_normalizer: `float` for how much to scale the loss on the IOU or the
        boxes.
      cls_normalizer: `float` for how much to scale the loss on the classes.
      obj_normalizer: `float` for how much to scale loss on the detection map.
      use_scaled_loss: `bool` for whether to use the scaled loss
        or the traditional loss.
      darknet: `bool` for whether to use the DarkNet or PyTorch loss function
        implementation.
      pre_nms_points: `int` number of top candidate detections per class before
        NMS.
      label_smoothing: `float` for how much to smooth the loss on the classes.
      max_boxes: `int` for the maximum number of boxes retained over all
        classes.
      new_cords: `bool` for using the ScaledYOLOv4 coordinates.
      path_scale: `dict` for the size of the input tensors. Defaults to
        `2**int(key)` for every key in `masks`.
      scale_xy: `dict` of `float` values indicating how far each pixel can see
        outside of its containment of 1.0. A value of 1.2 indicates there is a
        20% extended radius around each pixel that this specific pixel can
        predict values for a center at. The center can range from 0 - value/2
        to 1 + value/2. This value is set in the yolo filter, and reused here.
        There should be one value for scale_xy for each level from min_level
        to max_level. Defaults to 1.0 for every key in `masks`.
      nms_type: `str` for which non max suppression to use; one of 'greedy',
        'iou', 'giou', 'ciou', 'diou', 'class_independent', 'weighted_diou'.
      objectness_smooth: `float` for how much to smooth the loss on the
        detection map.
      **kwargs: Additional keyword arguments forwarded to `tf.keras.Model`.

    Raises:
      KeyError: if `nms_type` is not one of the supported names.
    """
    super().__init__(**kwargs)
    self._masks = masks
    self._anchors = anchors
    self._thresh = iou_thresh
    self._ignore_thresh = ignore_thresh
    self._truth_thresh = truth_thresh
    self._iou_normalizer = iou_normalizer
    self._cls_normalizer = cls_normalizer
    self._obj_normalizer = obj_normalizer
    self._objectness_smooth = objectness_smooth
    self._nms_thresh = nms_thresh
    self._max_boxes = max_boxes
    self._max_delta = max_delta
    self._classes = classes
    self._loss_type = loss_type

    self._use_scaled_loss = use_scaled_loss
    self._darknet = darknet

    self._pre_nms_points = pre_nms_points
    self._label_smoothing = label_smoothing
    self._keys = list(masks.keys())
    self._len_keys = len(self._keys)
    self._new_cords = new_cords
    self._path_scale = path_scale or {key: 2**int(key) for key in masks}

    self._nms_types = {
        'greedy': 1,
        'iou': 2,
        'giou': 3,
        'ciou': 4,
        'diou': 5,
        'class_independent': 6,
        'weighted_diou': 7
    }

    # Keep the user-facing name so that get_config() can round-trip it.
    self._nms_type_name = nms_type
    self._nms_type = self._nms_types[nms_type]

    self._scale_xy = scale_xy or {key: 1.0 for key in masks}

    self._generator = {}
    self._len_mask = {}
    for key in self._keys:
      level_anchors = [self._anchors[mask] for mask in self._masks[key]]
      self._generator[key] = self.get_generators(  # pylint: disable=assignment-from-none
          level_anchors, self._path_scale[key], key)
      self._len_mask[key] = len(self._masks[key])

  def get_generators(self, anchors, path_scale, path_key):
    """Returns the generator for one level; this implementation has none."""
    return None

  def rm_nan_inf(self, x, val=0.0):
    """Replaces every NaN and +/-Inf entry of `x` with `val`."""
    fill = tf.cast(val, dtype=x.dtype)
    x = tf.where(tf.math.is_nan(x), fill, x)
    x = tf.where(tf.math.is_inf(x), fill, x)
    return x

  def parse_prediction_path(self, key, inputs):
    """Splits one level's raw head output into scores and boxes.

    Args:
      key: `str` level key used to look up the number of anchors of the level.
      inputs: `Tensor` whose static shape provides height (dim 1) and width
        (dim 2) and whose last dimension holds `len_mask * (classes + 5)`
        values.

    Returns:
      A tuple `(obns_scores, boxes, class_scores)` with shapes `[batch, N]`,
      `[batch, N, 4]` and `[batch, N, classes]`, where
      `N = height * width * len_mask`.
    """
    shape = inputs.get_shape().as_list()
    height, width = shape[1], shape[2]

    len_mask = self._len_mask[key]

    # reshape the yolo output to (batchsize,
    #                             height,
    #                             width,
    #                             number_anchors,
    #                             remaining_points)
    data = tf.reshape(inputs, [-1, height, width, len_mask, self._classes + 5])

    # split the yolo detections into boxes, object score map, classes
    boxes, obns_scores, class_scores = tf.split(
        data, [4, 1, self._classes], axis=-1)

    # determine the number of classes
    classes = class_scores.get_shape().as_list()[-1]

    # convert boxes from yolo(x, y, w, h) to tensorflow(ymin, xmin, ymax, xmax)
    boxes = box_ops.xcycwh_to_yxyx(boxes)

    # activate the detection map
    obns_scores = tf.math.sigmoid(obns_scores)

    # threshold the detection map
    obns_mask = tf.cast(obns_scores > self._thresh, obns_scores.dtype)

    # convert detection map to class detection probabilities
    class_scores = tf.math.sigmoid(class_scores) * obns_mask * obns_scores
    class_scores *= tf.cast(class_scores > self._thresh, class_scores.dtype)

    fill = height * width * len_mask
    # flatten predictions to [batchsize, N, -1] for non max suppression
    boxes = tf.reshape(boxes, [-1, fill, 4])
    class_scores = tf.reshape(class_scores, [-1, fill, classes])
    obns_scores = tf.reshape(obns_scores, [-1, fill])

    return obns_scores, boxes, class_scores

  def call(self, inputs):
    """Generates the final detections from the per-level head outputs.

    Args:
      inputs: `dict` mapping each level key (e.g. '3') to that level's raw
        head output. Every level between the smallest and the largest key
        must be present.

    Returns:
      A `dict` with the keys 'bbox', 'classes', 'confidence' and
      'num_detections' holding the NMS results.
    """
    # Compare the levels numerically: the keys are strings, and a
    # lexicographic min/max misorders multi-digit levels ('10' < '9').
    levels = [int(level) for level in inputs]
    min_level, max_level = min(levels), max(levels)

    boxes = []
    class_scores = []
    object_scores = []

    # aggregate boxes over each scale
    for level in range(min_level, max_level + 1):
      key = str(level)
      object_scores_, boxes_, class_scores_ = self.parse_prediction_path(
          key, inputs[key])
      boxes.append(boxes_)
      class_scores.append(class_scores_)
      object_scores.append(object_scores_)

    # collate all predictions
    boxes = tf.concat(boxes, axis=1)
    object_scores = tf.concat(object_scores, axis=1)
    class_scores = tf.concat(class_scores, axis=1)

    # greedy NMS, always run in float32
    boxes = tf.cast(boxes, dtype=tf.float32)
    class_scores = tf.cast(class_scores, dtype=tf.float32)
    nms_items = tf.image.combined_non_max_suppression(
        tf.expand_dims(boxes, axis=-2),
        class_scores,
        self._pre_nms_points,
        self._max_boxes,
        iou_threshold=self._nms_thresh,
        score_threshold=self._thresh)

    # cast the boxes and predictions back to the original datatype
    out_dtype = object_scores.dtype
    nmsed_boxes = tf.cast(nms_items.nmsed_boxes, out_dtype)
    nmsed_classes = tf.cast(nms_items.nmsed_classes, out_dtype)
    nmsed_scores = tf.cast(nms_items.nmsed_scores, out_dtype)

    # compute the number of valid detections: every retained score is > 0 and
    # <= 1, so its ceil counts as exactly one detection.
    num_detections = tf.math.reduce_sum(tf.math.ceil(nmsed_scores), axis=-1)

    # format and return
    return {
        'bbox': nmsed_boxes,
        'classes': nmsed_classes,
        'confidence': nmsed_scores,
        'num_detections': num_detections,
    }

  @property
  def losses(self):
    """Generates a dictionary of losses to apply to each path.

    Done in the detection generator because all parameters are the same
    across both loss and detection generator.

    NOTE(review): this overrides `tf.keras.Model.losses` and currently returns
    None rather than a collection -- confirm no Keras code path iterates it.
    """
    return None

  def get_config(self):
    """Returns the constructor arguments needed to re-create this layer.

    The keys match the `__init__` argument names so that
    `YoloLayer(**layer.get_config())` rebuilds an equivalent layer.
    """
    return {
        'masks': dict(self._masks),
        'anchors': [list(a) for a in self._anchors],
        'classes': self._classes,
        'iou_thresh': self._thresh,
        'ignore_thresh': self._ignore_thresh,
        'truth_thresh': self._truth_thresh,
        'nms_thresh': self._nms_thresh,
        'max_delta': self._max_delta,
        'loss_type': self._loss_type,
        'iou_normalizer': self._iou_normalizer,
        'cls_normalizer': self._cls_normalizer,
        'obj_normalizer': self._obj_normalizer,
        'use_scaled_loss': self._use_scaled_loss,
        'darknet': self._darknet,
        'pre_nms_points': self._pre_nms_points,
        'label_smoothing': self._label_smoothing,
        'max_boxes': self._max_boxes,
        'new_cords': self._new_cords,
        'path_scale': dict(self._path_scale),
        'scale_xy': dict(self._scale_xy),
        'nms_type': self._nms_type_name,
        'objectness_smooth': self._objectness_smooth,
    }
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for yolo detection generator."""
16+
17+
from absl.testing import parameterized
18+
import tensorflow as tf
19+
20+
from official.vision.beta.projects.yolo.modeling.layers import detection_generator as dg
21+
22+
23+
class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
  """Shape checks for the yolo detection generator (`YoloLayer`)."""

  @parameterized.parameters(
      (True),
      (False),
  )
  def test_network_creation(self, nms):
    """Builds a YoloLayer, runs it on dummy inputs, checks output shapes."""
    # NOTE(review): the layer exposes no switch for this flag, so both
    # parameterizations currently exercise the same code path.
    del nms  # Unused.

    tf.keras.backend.set_image_data_format('channels_last')
    input_shape = {
        '3': [1, 52, 52, 255],
        '4': [1, 26, 26, 255],
        '5': [1, 13, 13, 255]
    }
    num_classes = 80
    max_boxes = 10
    masks = {'3': [0, 1, 2], '4': [3, 4, 5], '5': [6, 7, 8]}
    anchors = [[12.0, 19.0], [31.0, 46.0], [96.0, 54.0], [46.0, 114.0],
               [133.0, 127.0], [79.0, 225.0], [301.0, 150.0], [172.0, 286.0],
               [348.0, 340.0]]
    layer = dg.YoloLayer(masks, anchors, num_classes, max_boxes=max_boxes)

    # 255 channels = 3 anchors per level * (80 classes + 5 box/objectness).
    inputs = {
        key: tf.ones(shape, dtype=tf.float32)
        for key, shape in input_shape.items()
    }

    endpoints = layer(inputs)

    self.assertAllEqual(endpoints['bbox'].shape.as_list(), [1, max_boxes, 4])
    self.assertAllEqual(endpoints['classes'].shape.as_list(), [1, max_boxes])
    self.assertAllEqual(endpoints['confidence'].shape.as_list(),
                        [1, max_boxes])
    self.assertAllEqual(endpoints['num_detections'].shape.as_list(), [1])
55+
56+
57+
# Run the TensorFlow test suite when this module is executed as a script.
if __name__ == '__main__':
  tf.test.main()

official/vision/beta/projects/yolo/modeling/layers/nn_blocks.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
# Lint as: python3
1616
"""Contains common building blocks for yolo neural networks."""
17-
1817
from typing import Callable, List
1918
import tensorflow as tf
2019
from official.modeling import tf_utils
@@ -549,7 +548,7 @@ def __init__(self,
549548
550549
Args:
551550
filters: integer for output depth, or the number of features to learn
552-
filter_scale: integer dicating (filters//2) or the number of filters in
551+
filter_scale: integer dictating (filters//2) or the number of filters in
553552
the partial feature stack.
554553
activation: string for activation function to use in layer.
555554
kernel_initializer: string to indicate which function to use to
@@ -676,8 +675,8 @@ def __init__(self,
676675
"""Initializer for CSPConnect block.
677676
678677
Args:
679-
filters: integer for output depth, or the number of features to learn
680-
filter_scale: integer dicating (filters//2) or the number of filters in
678+
filters: integer for output depth, or the number of features to learn.
679+
filter_scale: integer dictating (filters//2) or the number of filters in
681680
the partial feature stack.
682681
drop_final: `bool`, whether to drop final conv layer.
683682
drop_first: `bool`, whether to drop first conv layer.
@@ -801,7 +800,7 @@ def __init__(self,
801800
model_to_wrap: callable Model or a list of callable objects that will
802801
process the output of CSPRoute, and be input into CSPConnect.
803802
list will be called sequentially.
804-
filter_scale: integer dicating (filters//2) or the number of filters in
803+
filter_scale: integer dictating (filters//2) or the number of filters in
805804
the partial feature stack.
806805
activation: string for activation function to use in layer.
807806
kernel_initializer: string to indicate which function to use to initialize

0 commit comments

Comments
 (0)