Skip to content

Commit 5d0a2f0

Browse files
committed
Added training script for ED-AFM.
1 parent 1da4976 commit 5d0a2f0

File tree

3 files changed

+519
-3
lines changed

3 files changed

+519
-3
lines changed

Diff for: mlspm/preprocessing.py

+127-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import random
2-
from typing import List, Tuple
2+
from typing import List, Literal, Optional, Tuple
33

44
import numpy as np
55
import scipy.ndimage as nimg
@@ -213,12 +213,136 @@ def interpolate_and_crop(
213213

214214

215215
def minimum_to_zero(Ys: List[np.ndarray]):
    """
    Shift values in arrays such that the minimum of each batch item is at zero. In-place operation.

    Arguments:
        Ys: Arrays of shape (batch_size, ...).
    """
    for Y in Ys:
        # Iterating over the first (batch) axis yields writable views, so the
        # in-place subtraction below modifies the original array.
        for sample in Y:
            sample -= sample.min()
225+
226+
227+
def add_rotation_reflection(
    X: List[np.ndarray],
    Y: List[np.ndarray],
    reflections: bool = True,
    multiple: int = 2,
    crop: Optional[Tuple[int, int]] = None,
    per_batch_item: bool = False,
):
    """
    Augment batch with random rotations and reflections.

    Arguments:
        X: AFM images to augment. Each array should be of shape ``(batch_size, x, y, z)``.
        Y: Reference image descriptors to augment. Each array should be of shape ``(batch, x, y)``.
        reflections: Whether to augment with reflections. If True, each rotation is randomly reflected with 50% probability.
        multiple: Multiplier for how many rotations to generate for every sample.
        crop: If not None, then output batch is cropped to specified size ``(x_size, y_size)`` in the middle of the image.
        per_batch_item: If True, rotation is randomized per batch item, otherwise same rotation for all.

    Returns:
        Tuple (**X**, **Y**), where

        - **X** - Batch of rotation-augmented AFM images of shape ``(batch*multiple, x_new, y_new, z)``.
        - **Y** - Batch of rotation-augmented reference image descriptors of shape ``(batch*multiple, x_new, y_new)``
    """

    X_aug = [[] for _ in range(len(X))]
    Y_aug = [[] for _ in range(len(Y))]

    for _ in range(multiple):
        if per_batch_item:
            rotations = 360 * np.random.rand(len(X[0]))
        else:
            rotations = [360 * np.random.rand()] * len(X[0])
        # Bug fix: `flip` was previously assigned only when reflections=True, so the
        # `if flip:` checks below raised a NameError for reflections=False.
        flip = np.random.randint(2) if reflections else 0
        for k, x in enumerate(X):
            x = x.copy()
            for i in range(x.shape[0]):
                for j in range(x.shape[-1]):
                    # `Image` (PIL) is expected to be imported at module level.
                    x[i, :, :, j] = np.array(Image.fromarray(x[i, :, :, j]).rotate(rotations[i], resample=Image.BICUBIC))
            if flip:
                x = x[:, :, ::-1]
            X_aug[k].append(x)
        for k, y in enumerate(Y):
            y = y.copy()
            for i in range(y.shape[0]):
                y[i, :, :] = np.array(Image.fromarray(y[i, :, :]).rotate(rotations[i], resample=Image.BICUBIC))
            if flip:
                y = y[:, :, ::-1]
            Y_aug[k].append(y)

    X = [np.concatenate(x, axis=0) for x in X_aug]
    Y = [np.concatenate(y, axis=0) for y in Y_aug]

    if crop is not None:
        # Center crop: same offsets are applied to both X and Y, which assumes they
        # share spatial dimensions.
        x_start = (X[0].shape[1] - crop[0]) // 2
        y_start = (X[0].shape[2] - crop[1]) // 2
        X = [x[:, x_start : x_start + crop[0], y_start : y_start + crop[1]] for x in X]
        Y = [y[:, x_start : x_start + crop[0], y_start : y_start + crop[1]] for y in Y]

    return X, Y
289+
290+
291+
def random_crop(
    X: List[np.ndarray],
    Y: List[np.ndarray],
    min_crop: float = 0.5,
    max_aspect: float = 2.0,
    multiple: int = 8,
    distribution: Literal["flat", "exp-log"] = "flat",
):
    """
    Randomly crop images in a batch to a different size and aspect ratio.

    Arguments:
        X: AFM images to crop. Each array should be of shape ``(batch_size, x, y, z)``.
        Y: Reference image descriptors to crop. Each array should be of shape ``(batch, x, y)``.
        min_crop: Minimum crop size as a fraction of the original size.
        max_aspect: Maximum aspect ratio for crop. Cannot be more than 1/min_crop.
        multiple: The crop size is rounded down to the specified integer multiple.
        distribution: 'flat' or 'exp-log'. How aspect ratios are distributed. If 'flat', then distribution is random uniform
            between (1, max_aspect) and half of time is flipped. If 'exp-log', then distribution is exp of log of uniform
            distribution over (1/max_aspect, max_aspect). 'exp-log' is more biased towards square aspect ratios.

    Returns:
        Tuple (**X**, **Y**), where

        - **X** - Batch of cropped AFM images of shape ``(batch, x_new, y_new, z)``.
        - **Y** - Batch of cropped reference image descriptors of shape ``(batch, x_new, y_new)``.
    """
    assert 0 < min_crop <= 1.0
    assert max_aspect >= 1.0
    # Guarantees the sampling intervals below are non-empty.
    assert 1 / min_crop >= max_aspect

    # Sample an aspect ratio according to the requested distribution.
    if distribution == "flat":
        aspect = np.random.uniform(1, max_aspect)
        aspect = 1 / aspect if np.random.rand() > 0.5 else aspect
    elif distribution == "exp-log":
        aspect = np.exp(np.random.uniform(np.log(1 / max_aspect), np.log(max_aspect)))
    else:
        raise ValueError(f"Unrecognized aspect ratio distribution {distribution}")

    x_size, y_size = X[0].shape[1], X[0].shape[2]
    # Sample the shorter side uniformly, then derive the longer one from the aspect ratio.
    if aspect > 1.0:
        crop_h = int(np.random.uniform(int(min_crop * y_size), int(y_size / aspect)))
        crop_w = int(crop_h * aspect)
    else:
        crop_w = int(np.random.uniform(int(min_crop * x_size), int(x_size * aspect)))
        crop_h = int(crop_w / aspect)

    # Round both dimensions down to the requested multiple.
    crop_w -= crop_w % multiple
    crop_h -= crop_h % multiple

    # Random top-left corner; the small epsilon keeps the upper bound exclusive
    # so the crop always fits inside the image.
    x0 = int(np.random.uniform(0, x_size - crop_w - 1e-6))
    y0 = int(np.random.uniform(0, y_size - crop_h - 1e-6))

    sx = slice(x0, x0 + crop_w)
    sy = slice(y0, y0 + crop_h)
    X = [x[:, sx, sy] for x in X]
    Y = [y[:, sx, sy] for y in Y]

    return X, Y

Diff for: papers/ed-afm/run_train.sh

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#!/bin/bash
# Distributed training launcher for the ED-AFM model.

# Number of GPUs, and samples per batch per GPU (total batch size = N_GPU x BATCH_SIZE).
N_GPU=1
BATCH_SIZE=30

# Parallel data-loading workers per GPU.
N_WORKERS=8

# Keep each worker process single-threaded to avoid oversubscribing CPU cores.
export OMP_NUM_THREADS=1

torchrun \
    --standalone \
    --nnodes 1 \
    --nproc_per_node "$N_GPU" \
    --max_restarts 0 \
    train.py \
    --run_dir ./train \
    --data_dir ./data \
    --urls_train "data-K-0_train_{0..7}.tar" \
    --urls_val "data-K-0_val_{0..7}.tar" \
    --urls_test "data-K-0_test_{0..7}.tar" \
    --random_seed 0 \
    --train True \
    --test True \
    --predict True \
    --epochs 50 \
    --num_workers "$N_WORKERS" \
    --batch_size "$BATCH_SIZE" \
    --avg_best_epochs 3 \
    --pred_batches 3 \
    --lr 1e-4 \
    --loss_labels "ES" \
    --loss_weights 1.0 \
    --timings

0 commit comments

Comments
 (0)