
Commit 3afd339

Internal change
PiperOrigin-RevId: 485693087
1 parent 69bbdc1 commit 3afd339

File tree

2 files changed (+193, −0 lines changed)

official/projects/edgetpu/vision/modeling/custom_layers.py

Lines changed: 127 additions & 0 deletions
@@ -477,3 +477,130 @@ def call(self, inputs: Any) -> Any:
        axis=self.axis,
        output_type=self.output_type,
        name=self.name)


_or = tf.maximum
_and = tf.minimum
_reduce_or = tf.reduce_max


def _tensor_sum_vectors(a, b):
  return tf.reshape(a, [1, 1, 1, -1]) + tf.reshape(b, [1, 1, -1, 1])


def _tensor_product_iou(boxes):
  """Computes pairwise IOU.

  The reason to use 4-D tensors is to follow the TPU compiler's preference.

  Args:
    boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`.

  Returns:
    A 4-D float `Tensor` of shape `[1, 1, num_boxes, num_boxes]` containing
    pairwise IOU.
  """
  boxes = tf.reshape(boxes, [-1, 4])
  boxes = tf.transpose(boxes, [1, 0])
  bottom, left, top, right = tf.split(boxes, 4, 0)
  height, width = top - bottom, right - left
  area = height * width
  area_sum = _tensor_sum_vectors(area, area)
  bottom_pad, left_pad, top_pad, right_pad = (
      tf.nn.relu(_tensor_sum_vectors(x, -x))
      for x in (-bottom, -left, top, right))
  height_pad, width_pad = bottom_pad + top_pad, left_pad + right_pad
  intersection = tf.nn.relu(height - height_pad) * tf.nn.relu(width - width_pad)
  union = area_sum - intersection
  iou = tf.math.divide(intersection, union + _same(union))
  return iou


def _greater(x):
  """Avoids non-lowerable layers in boolean comparison.

  A logical operation results in a tensor of boolean type. However, in serving
  such tensors cannot be cast to values because of NNAPI specs. The `tf.where`
  operation results in a `select` instruction lowering, which does not run
  well on all generations of Edge TPUs.

  Args:
    x: Any numeric tensor.

  Returns:
    Numeric equivalent of
    tf.where(x > tf.zeros_like(x), tf.ones_like(x), tf.zeros_like(x)).
  """
  x_clip = tf.minimum(tf.nn.relu(x), tf.constant(1, dtype=x.dtype))
  return -tf.math.floor(-x_clip)


def _same(x):
  """Avoids non-lowerable layers in boolean equality.

  A logical operation results in a tensor of boolean type. However, in serving
  such tensors cannot be cast to values because of NNAPI specs. The `tf.where`
  operation results in a `select` instruction lowering, which does not run
  well on all generations of Edge TPUs.

  Args:
    x: Any numeric tensor.

  Returns:
    Numeric equivalent of
    tf.where(x == tf.zeros_like(x), tf.ones_like(x), tf.zeros_like(x)).
  """
  x_clip = tf.minimum(tf.abs(x), tf.constant(1, dtype=x.dtype))
  return tf.constant(1, dtype=x.dtype) + tf.math.floor(-x_clip)


def non_max_suppression_padded(boxes: tf.Tensor,
                               scores: tf.Tensor,
                               output_size: int,
                               iou_threshold: float = 0.5) -> tf.Tensor:
  """Selects a subset of boxes with the highest score among IOU-similar boxes.

  Prunes away boxes that have a high intersection-over-union (IOU) overlap
  with boxes having a higher score. Boxes are supplied as `[y1, x1, y2, x2]`,
  where `(y1, x1)` and `(y2, x2)` are the coordinates of any diagonal pair of
  box corners. Note that this algorithm is agnostic to the coordinate system.
  Thus translations or reflections of the coordinate system result in the same
  boxes being selected by the algorithm. The output of this operation is a
  set of integers indexing into the input collection of bounding boxes
  representing the selected boxes.

  The set is returned padded on the right with `-1` values. The bounding
  box coordinates corresponding to the selected indices can then be obtained
  using the `tf.gather` operation. For example:
  ```python
  selected_indices = vision.modeling.layers.non_max_suppression_padded(
      boxes, scores, max_output_size, iou_threshold)
  selected_boxes = tf.gather(boxes, selected_indices)
  ```

  Args:
    boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`.
    scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single
      score corresponding to each box (each row of boxes).
    output_size: A scalar integer `Tensor` representing the maximum number of
      boxes to be selected by non-max suppression.
    iou_threshold: A 0-D float tensor representing the threshold for deciding
      whether boxes overlap too much with respect to IOU.

  Returns:
    selected_indices: A 1-D integer `Tensor` of shape `[output_size]`
      representing the selected indices from the boxes tensor and `-1` values
      for the padding.
  """
  order = tf.range(tf.size(scores), dtype=tf.float32)
  relative_order = _tensor_sum_vectors(order, -order)
  relative_scores = _tensor_sum_vectors(scores, -scores)
  similar = _greater(_tensor_product_iou(boxes) - iou_threshold)
  worse = _greater(relative_scores)
  same_later = _and(_same(relative_scores), _greater(relative_order))
  similar_worse_or_same_later = _and(similar, _or(worse, same_later))
  prunable = _reduce_or(similar_worse_or_same_later, axis=-1)
  remaining = (tf.constant(1.) - prunable)
  # top_k runs on the TPU cores; let it happen there, since the TPU tiles
  # implementation is slower.
  top_k = tf.math.top_k(remaining * tf.exp(scores), output_size)
  return tf.squeeze(
      tf.cast(top_k.indices, top_k.values.dtype) * _greater(top_k.values) -
      _same(top_k.values))
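
The `_greater` and `_same` helpers above replace boolean comparison and equality with pure arithmetic (`relu`, `abs`, `minimum`, `floor`), so the exported graph avoids the `select` lowering that does not run well on all Edge TPU generations. Below is a minimal standalone sketch of that trick; the names `greater_indicator` and `same_indicator` are hypothetical local copies of the helpers, used here only so the snippet runs without the repo on the path.

```python
import tensorflow as tf


def greater_indicator(x):
  # Same arithmetic as `_greater`: relu drops negatives, `minimum` caps at 1,
  # and -floor(-x) rounds any value in (0, 1] up to 1, so the result is 1.0
  # where x > 0 and 0.0 elsewhere, with no boolean tensors involved.
  x_clip = tf.minimum(tf.nn.relu(x), tf.constant(1, dtype=x.dtype))
  return -tf.math.floor(-x_clip)


def same_indicator(x):
  # Same arithmetic as `_same`: |x| capped at 1, then 1 + floor(-x) maps
  # 0 -> 1 and anything in (0, 1] -> 0.
  x_clip = tf.minimum(tf.abs(x), tf.constant(1, dtype=x.dtype))
  return tf.constant(1, dtype=x.dtype) + tf.math.floor(-x_clip)


x = tf.constant([-2.0, -0.5, 0.0, 0.3, 4.0])
print(greater_indicator(x).numpy())                    # [0. 0. 0. 1. 1.]
print(same_indicator(x).numpy())                       # [0. 0. 1. 0. 0.]
# Boolean references from the docstrings, for comparison only:
print(tf.cast(x > 0.0, tf.float32).numpy())            # [0. 0. 0. 1. 1.]
print(tf.cast(tf.equal(x, 0.0), tf.float32).numpy())   # [0. 0. 1. 0. 0.]
```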

official/projects/edgetpu/vision/modeling/custom_layers_test.py

Lines changed: 66 additions & 0 deletions
@@ -17,6 +17,7 @@
import itertools

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.projects.edgetpu.vision.modeling import custom_layers

@@ -186,5 +187,70 @@ def test_reference_match(self, shape, input_type, output_type):
    self.assertAllEqual(control_output, test_output)


def random_boxes(n):
  a = tf.random.uniform(shape=[n, 2])
  b = tf.random.uniform(shape=[n, 2])
  l = tf.minimum(a, b)
  u = tf.maximum(a, b)
  return tf.concat([l, u], axis=-1)


class NonMaxSuppressionTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters((16, 8, 500, 0.016), (31, 17, 300, 0.033),
                            (71, 41, 300, 0.065), (150, 100, 250, 0.137),
                            (300, 300, 250, 0.126), (600, 600, 100, 0.213))
  def test_reference_match(self, n, top, runs, max_deviation):
    """Checks that the new optimized method stays close to the reference one.

    Runs the two algorithms on the same sets of input boxes and scores, and
    measures the deviation between the returned sets of pruned boxes.
    (*) Avoid flakiness with a safe boundary (go/python-tips/048): the
    deviation between two sets is a positive number, which may vary from test
    to test. Doing multiple runs is expected to reduce the variation of the
    average deviation, following the law of large numbers (LLN). Therefore,
    from the first test run we know an upper deviation bound which the
    algorithm would not exceed until broken (in any feasible amount of time in
    the future). Use of this safe boundary makes the test non-flaky.
    (**) Parametrized inputs description. Note that the chosen safe threshold
    is higher than the measured deviation to avoid flaky testing.
    in # | out # | deflake # | test time | deviation | safe threshold
    ---- | ----- | --------- | --------- | --------- | --------------
      16 |     8 |       500 |     6 sec |      0.4% |           1.6%
      31 |    17 |       300 |     7 sec |      1.0% |           3.3%
      71 |    41 |       300 |     7 sec |      3.4% |           6.5%
     150 |   100 |       250 |     7 sec |      8.2% |          13.7%
     300 |   300 |       250 |    10 sec |      7.4% |          12.6%
     600 |   600 |       100 |     9 sec |      9.6% |          21.3%

    Args:
      n: number of boxes and scores on input of the algorithm.
      top: limit of output boxes count.
      runs: number of runs to perform for the statistical testing, to avoid
        test flakiness.
      max_deviation: limit on the mean deviation between the optimized and
        reference algorithms. Please read the notes on why this number may be
        set higher than the measured deviation to avoid flaky testing.
    """
    deviation_rate = 0
    for _ in range(runs):
      boxes = random_boxes(n)
      scores = tf.random.uniform(shape=[n])
      optimized = custom_layers.non_max_suppression_padded(boxes, scores, top)
      optimized = {*optimized.numpy().astype(int).tolist()} - {-1}
      reference = tf.image.non_max_suppression(boxes, scores, top)
      reference = {*reference.numpy().tolist()}
      deviation_rate += len(optimized ^ reference) / len(optimized | reference)
    deviation_rate = deviation_rate / runs
    # Six-sigma estimate via the law of large numbers.
    safe_margin = 6 * (deviation_rate / np.sqrt(runs) + 1 / runs)
    self.assertLess(
        deviation_rate,
        max_deviation,
        msg='Deviation rate between optimized and reference implementations is '
        'higher than expected. If you are tuning the test, the recommended safe '
        'deviation rate is '
        f'{deviation_rate} + {safe_margin} = {deviation_rate + safe_margin}')


if __name__ == '__main__':
  tf.test.main()
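
The safe thresholds in the docstring table appear to track the six-sigma margin the test reports in its failure message. A rough back-of-the-envelope sketch, using the first table row (~0.4% measured deviation over 500 deflaking runs) as purely illustrative numbers:

```python
import numpy as np

# Illustrative numbers from the first table row: ~0.4% measured deviation
# over 500 deflaking runs.
deviation_rate = 0.004
runs = 500

# Same margin formula as in the test above.
safe_margin = 6 * (deviation_rate / np.sqrt(runs) + 1 / runs)

print(round(safe_margin, 4))                   # ~0.0131
# Recommended threshold, as reported by the test's failure message; roughly
# in line with the 1.6% safe threshold listed for the first parameter row.
print(round(deviation_rate + safe_margin, 4))  # ~0.0171
```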
