|
15 | 15 | # ==============================================================================
|
16 | 16 | """Utils for processing video dataset features."""
|
17 | 17 |
|
18 |
| -from typing import Optional |
| 18 | +from typing import Optional, Tuple |
19 | 19 | import tensorflow as tf
|
20 | 20 |
|
21 | 21 |
|
@@ -217,6 +217,55 @@ def resize_fn():
|
217 | 217 | return frames
|
218 | 218 |
|
219 | 219 |
|
def random_crop_resize(frames: tf.Tensor,
                       output_h: int,
                       output_w: int,
                       num_frames: int,
                       num_channels: int,
                       aspect_ratio: Tuple[float, float],
                       area_range: Tuple[float, float]) -> tf.Tensor:
  """First crops clip with jittering and then resizes to (output_h, output_w).

  Args:
    frames: A Tensor of dimension [timesteps, input_h, input_w, channels].
    output_h: Resized image height.
    output_w: Resized image width.
    num_frames: Number of input frames per clip.
    num_channels: Number of channels of the clip.
    aspect_ratio: Float tuple with the aspect range for cropping.
    area_range: Float tuple with the area range for cropping.
  Returns:
    A Tensor of shape [timesteps, output_h, output_w, channels] of type
    frames.dtype.
  """
  input_shape = tf.shape(frames)
  seq_len = input_shape[0]
  channels = input_shape[3]
  # Express the requested aspect-ratio range relative to the output geometry
  # instead of a square crop.
  ratio_scale = output_w / output_h
  scaled_aspect_ratio = (aspect_ratio[0] * ratio_scale,
                         aspect_ratio[1] * ratio_scale)
  # A single whole-image bounding box; the sampler jitters crops within it.
  whole_image_bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
                                 dtype=tf.float32,
                                 shape=[1, 1, 4])
  crop_begin, crop_size, _ = tf.image.sample_distorted_bounding_box(
      input_shape[1:],
      bounding_boxes=whole_image_bbox,
      min_object_covered=0.1,
      aspect_ratio_range=scaled_aspect_ratio,
      area_range=area_range,
      max_attempts=100,
      use_image_if_no_bounding_boxes=True)
  crop_top, crop_left, _ = tf.unstack(crop_begin)
  crop_h, crop_w, _ = tf.unstack(crop_size)
  # Apply the same spatial crop to every frame of the clip.
  frames = tf.slice(
      frames,
      tf.convert_to_tensor((0, crop_top, crop_left, 0)),
      tf.convert_to_tensor((seq_len, crop_h, crop_w, channels)))
  # tf.image.resize yields float32; cast back to the caller's dtype.
  frames = tf.cast(tf.image.resize(frames, (output_h, output_w)),
                   frames.dtype)
  frames.set_shape((num_frames, output_h, output_w, num_channels))
  return frames
| 267 | + |
| 268 | + |
220 | 269 | def random_flip_left_right(
|
221 | 270 | frames: tf.Tensor,
|
222 | 271 | seed: Optional[int] = None) -> tf.Tensor:
|
|
0 commit comments