
Commit 11ea523

Support random_crop_resize function in preprocess_ops_3d.
PiperOrigin-RevId: 348500556
1 parent: 63bdcfb

File tree: 2 files changed (+64, −1)

  official/vision/beta/ops/preprocess_ops_3d.py
  official/vision/beta/ops/preprocess_ops_3d_test.py

official/vision/beta/ops/preprocess_ops_3d.py

Lines changed: 50 additions & 1 deletion

@@ -15,7 +15,7 @@
 # ==============================================================================
 """Utils for processing video dataset features."""

-from typing import Optional
+from typing import Optional, Tuple
 import tensorflow as tf


@@ -217,6 +217,55 @@ def resize_fn():
   return frames


+def random_crop_resize(frames: tf.Tensor,
+                       output_h: int,
+                       output_w: int,
+                       num_frames: int,
+                       num_channels: int,
+                       aspect_ratio: Tuple[float, float],
+                       area_range: Tuple[float, float]) -> tf.Tensor:
+  """First crops clip with jittering and then resizes to (output_h, output_w).
+
+  Args:
+    frames: A Tensor of dimension [timesteps, input_h, input_w, channels].
+    output_h: Resized image height.
+    output_w: Resized image width.
+    num_frames: Number of input frames per clip.
+    num_channels: Number of channels of the clip.
+    aspect_ratio: Float tuple with the aspect range for cropping.
+    area_range: Float tuple with the area range for cropping.
+  Returns:
+    A Tensor of shape [timesteps, output_h, output_w, channels] of type
+      frames.dtype.
+  """
+  shape = tf.shape(frames)
+  seq_len, _, _, channels = shape[0], shape[1], shape[2], shape[3]
+  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
+  factor = output_w / output_h
+  aspect_ratio = (aspect_ratio[0] * factor, aspect_ratio[1] * factor)
+  sample_distorted_bbox = tf.image.sample_distorted_bounding_box(
+      shape[1:],
+      bounding_boxes=bbox,
+      min_object_covered=0.1,
+      aspect_ratio_range=aspect_ratio,
+      area_range=area_range,
+      max_attempts=100,
+      use_image_if_no_bounding_boxes=True)
+  bbox_begin, bbox_size, _ = sample_distorted_bbox
+  offset_y, offset_x, _ = tf.unstack(bbox_begin)
+  target_height, target_width, _ = tf.unstack(bbox_size)
+  size = tf.convert_to_tensor((
+      seq_len, target_height, target_width, channels))
+  offset = tf.convert_to_tensor((
+      0, offset_y, offset_x, 0))
+  frames = tf.slice(frames, offset, size)
+  frames = tf.cast(
+      tf.image.resize(frames, (output_h, output_w)),
+      frames.dtype)
+  frames.set_shape((num_frames, output_h, output_w, num_channels))
+  return frames
+
+
 def random_flip_left_right(
     frames: tf.Tensor,
     seed: Optional[int] = None) -> tf.Tensor:
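
For orientation (not part of the commit): the new op samples one jittered crop window per clip via tf.image.sample_distorted_bounding_box, applies that same window to every frame with tf.slice, and then resizes the crop, so the augmentation is temporally consistent across the clip. Note also that the requested aspect_ratio range is rescaled by output_w / output_h before sampling, i.e. it is interpreted relative to the output aspect. A minimal usage sketch, assuming TF2 and a dummy 6-frame 90x120 RGB clip (the shapes and parameter values here are illustrative assumptions):

  import tensorflow as tf

  from official.vision.beta.ops import preprocess_ops_3d

  # Dummy clip: 6 RGB frames of size 90x120 (values are arbitrary).
  frames = tf.random.uniform((6, 90, 120, 3), maxval=255, dtype=tf.float32)

  # Crop a window covering 30-100% of the input area, with width/height
  # between 0.5x and 2.0x the (square) output aspect, then resize to 224x224.
  # The same randomly sampled window is applied to all 6 frames.
  clip = preprocess_ops_3d.random_crop_resize(
      frames,
      output_h=224,
      output_w=224,
      num_frames=6,
      num_channels=3,
      aspect_ratio=(0.5, 2.0),
      area_range=(0.3, 1.0))

  print(clip.shape)  # (6, 224, 224, 3)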

official/vision/beta/ops/preprocess_ops_3d_test.py

Lines changed: 14 additions & 0 deletions

@@ -119,6 +119,20 @@ def test_resize_smallest(self):
     self.assertAllEqual(resized_frames_3.shape, (6, 90, 120, 3))
     self.assertAllEqual(resized_frames_4.shape, (6, 60, 45, 3))

+  def test_random_crop_resize(self):
+    resized_frames_1 = preprocess_ops_3d.random_crop_resize(
+        self._frames, 256, 256, 6, 3, (0.5, 2), (0.3, 1))
+    resized_frames_2 = preprocess_ops_3d.random_crop_resize(
+        self._frames, 224, 224, 6, 3, (0.5, 2), (0.3, 1))
+    resized_frames_3 = preprocess_ops_3d.random_crop_resize(
+        self._frames, 256, 256, 6, 3, (0.8, 1.2), (0.3, 1))
+    resized_frames_4 = preprocess_ops_3d.random_crop_resize(
+        self._frames, 256, 256, 6, 3, (0.5, 2), (0.1, 1))
+    self.assertAllEqual(resized_frames_1.shape, (6, 256, 256, 3))
+    self.assertAllEqual(resized_frames_2.shape, (6, 224, 224, 3))
+    self.assertAllEqual(resized_frames_3.shape, (6, 256, 256, 3))
+    self.assertAllEqual(resized_frames_4.shape, (6, 256, 256, 3))
+
   def test_random_flip_left_right(self):
     flipped_frames = preprocess_ops_3d.random_flip_left_right(self._frames)
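
The four test cases vary only the output size, aspect range, and area range, and assert the static output shape; the randomness of the crop window cannot affect that shape because set_shape pins it. A standalone sketch of the same check, assuming self._frames is a (6, 90, 120, 3) clip as suggested by the resize tests above (the exact height/width is an assumption; this hunk only fixes 6 frames and 3 channels):

  import tensorflow as tf

  from official.vision.beta.ops import preprocess_ops_3d

  # Assumed test fixture: a 6-frame uint8 clip (height/width are guesses).
  frames = tf.zeros((6, 90, 120, 3), dtype=tf.uint8)

  out = preprocess_ops_3d.random_crop_resize(
      frames, 256, 256, 6, 3, (0.5, 2), (0.3, 1))
  assert out.shape == (6, 256, 256, 3)  # static shape set by set_shape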
