|
1 | 1 | import random
|
2 |
| -from typing import List, Tuple |
| 2 | +from typing import List, Literal, Optional, Tuple |
3 | 3 |
|
4 | 4 | import numpy as np
|
5 | 5 | import scipy.ndimage as nimg
|
@@ -213,12 +213,136 @@ def interpolate_and_crop(
|
213 | 213 |
|
214 | 214 |
|
215 | 215 | def minimum_to_zero(Ys: List[np.ndarray]):
|
216 |
| - ''' |
| 216 | + """ |
217 | 217 | Shift values in arrays such that minimum is at zero. In-place operation.
|
218 | 218 |
|
219 | 219 | Arguments:
|
220 | 220 | Ys: Arrays of shape (batch_size, ...).
|
221 |
| - ''' |
| 221 | + """ |
222 | 222 | for Y in Ys:
|
223 | 223 | for j in range(Y.shape[0]):
|
224 | 224 | Y[j] -= Y[j].min()
|
| 225 | + |
| 226 | + |
| 227 | +def add_rotation_reflection( |
| 228 | + X: List[np.ndarray], |
| 229 | + Y: List[np.ndarray], |
| 230 | + reflections: bool = True, |
| 231 | + multiple: int = 2, |
| 232 | + crop: Optional[Tuple[int]] = None, |
| 233 | + per_batch_item: bool = False, |
| 234 | +): |
| 235 | + """ |
| 236 | + Augment batch with random rotations and reflections. |
| 237 | +
|
| 238 | + Arguments: |
| 239 | + X: AFM images to augment. Each array should be of shape ``(batch_size, x, y, z)``. |
| 240 | + Y: Reference image descriptors to augment. Each array should be of shape ``(batch, x, y)``. |
| 241 | + reflections: Whether to augment with reflections. If True, each rotation is randomly reflected with 50% probability. |
| 242 | + multiple: Multiplier for how many rotations to generate for every sample. |
| 243 | + crop: If not None, then output batch is cropped to specified size ``(x_size, y_size)`` in the middle of the image. |
| 244 | + per_batch_item: If True, rotation is randomized per batch item, otherwise same rotation for all. |
| 245 | +
|
| 246 | + Returns: |
| 247 | + Tuple (**X**, **Y**), where |
| 248 | +
|
| 249 | + - **X** - Batch of rotation-augmented AFM images of shape ``(batch*multiple, x_new, y_new, z)``. |
| 250 | + - **Y** - Batch of rotation-augmented reference image descriptors of shape ``(batch*multiple, x_new, y_new)`` |
| 251 | + """ |
| 252 | + |
| 253 | + X_aug = [[] for _ in range(len(X))] |
| 254 | + Y_aug = [[] for _ in range(len(Y))] |
| 255 | + |
| 256 | + for _ in range(multiple): |
| 257 | + if per_batch_item: |
| 258 | + rotations = 360 * np.random.rand(len(X[0])) |
| 259 | + else: |
| 260 | + rotations = [360 * np.random.rand()] * len(X[0]) |
| 261 | + if reflections: |
| 262 | + flip = np.random.randint(2) |
| 263 | + for k, x in enumerate(X): |
| 264 | + x = x.copy() |
| 265 | + for i in range(x.shape[0]): |
| 266 | + for j in range(x.shape[-1]): |
| 267 | + x[i, :, :, j] = np.array(Image.fromarray(x[i, :, :, j]).rotate(rotations[i], resample=Image.BICUBIC)) |
| 268 | + if flip: |
| 269 | + x = x[:, :, ::-1] |
| 270 | + X_aug[k].append(x) |
| 271 | + for k, y in enumerate(Y): |
| 272 | + y = y.copy() |
| 273 | + for i in range(y.shape[0]): |
| 274 | + y[i, :, :] = np.array(Image.fromarray(y[i, :, :]).rotate(rotations[i], resample=Image.BICUBIC)) |
| 275 | + if flip: |
| 276 | + y = y[:, :, ::-1] |
| 277 | + Y_aug[k].append(y) |
| 278 | + |
| 279 | + X = [np.concatenate(x, axis=0) for x in X_aug] |
| 280 | + Y = [np.concatenate(y, axis=0) for y in Y_aug] |
| 281 | + |
| 282 | + if crop is not None: |
| 283 | + x_start = (X[0].shape[1] - crop[0]) // 2 |
| 284 | + y_start = (X[0].shape[2] - crop[1]) // 2 |
| 285 | + X = [x[:, x_start : x_start + crop[0], y_start : y_start + crop[1]] for x in X] |
| 286 | + Y = [y[:, x_start : x_start + crop[0], y_start : y_start + crop[1]] for y in Y] |
| 287 | + |
| 288 | + return X, Y |
| 289 | + |
| 290 | + |
| 291 | +def random_crop( |
| 292 | + X: List[np.ndarray], |
| 293 | + Y: List[np.ndarray], |
| 294 | + min_crop: float = 0.5, |
| 295 | + max_aspect: float = 2.0, |
| 296 | + multiple: int = 8, |
| 297 | + distribution: Literal["flat", "exp-log"] = "flat", |
| 298 | +): |
| 299 | + """ |
| 300 | + Randomly crop images in a batch to a different size and aspect ratio. |
| 301 | +
|
| 302 | + Arguments: |
| 303 | + X: AFM images to crop. Each array should be of shape ``(batch_size, x, y, z)``. |
| 304 | + Y: Reference image descriptors to crop. Each array should be of shape ``(batch, x, y)``. |
| 305 | + min_crop: Minimum crop size as a fraction of the original size. |
| 306 | + max_aspect: Maximum aspect ratio for crop. Cannot be more than 1/min_crop. |
| 307 | + multiple: The crop size is rounded down to the specified integer multiple. |
| 308 | + distribution: 'flat' or 'exp-log'. How aspect ratios are distributed. If 'flat', then distribution is random uniform |
| 309 | + between (1, max_aspect) and half of time is flipped. If 'exp-log', then distribution is exp of log of uniform |
| 310 | + distribution over (1/max_aspect, max_aspect). 'exp-log' is more biased towards square aspect ratios. |
| 311 | +
|
| 312 | + Returns: |
| 313 | + Tuple (**X**, **Y**), where |
| 314 | +
|
| 315 | + - **X** - Batch of cropped AFM images of shape ``(batch, x_new, y_new, z)``. |
| 316 | + - **Y** - Batch of cropped reference image descriptors of shape ``(batch, x_new, y_new)``. |
| 317 | + """ |
| 318 | + assert 0 < min_crop <= 1.0 |
| 319 | + assert max_aspect >= 1.0 |
| 320 | + assert 1 / min_crop >= max_aspect |
| 321 | + |
| 322 | + if distribution == "flat": |
| 323 | + aspect = np.random.uniform(1, max_aspect) |
| 324 | + if np.random.rand() > 0.5: |
| 325 | + aspect = 1 / aspect |
| 326 | + elif distribution == "exp-log": |
| 327 | + aspect = np.exp(np.random.uniform(np.log(1 / max_aspect), np.log(max_aspect))) |
| 328 | + else: |
| 329 | + raise ValueError(f"Unrecognized aspect ratio distribution {distribution}") |
| 330 | + |
| 331 | + x_size, y_size = X[0].shape[1], X[0].shape[2] |
| 332 | + if aspect > 1.0: |
| 333 | + height = int(np.random.uniform(int(min_crop * y_size), int(y_size / aspect))) |
| 334 | + width = int(height * aspect) |
| 335 | + else: |
| 336 | + width = int(np.random.uniform(int(min_crop * x_size), int(x_size * aspect))) |
| 337 | + height = int(width / aspect) |
| 338 | + |
| 339 | + width = width - (width % multiple) |
| 340 | + height = height - (height % multiple) |
| 341 | + |
| 342 | + start_x = int(np.random.uniform(0, x_size - width - 1e-6)) |
| 343 | + start_y = int(np.random.uniform(0, y_size - height - 1e-6)) |
| 344 | + |
| 345 | + X = [x[:, start_x : start_x + width, start_y : start_y + height] for x in X] |
| 346 | + Y = [y[:, start_x : start_x + width, start_y : start_y + height] for y in Y] |
| 347 | + |
| 348 | + return X, Y |
0 commit comments