|
| 1 | +import time |
| 2 | +from math import ceil |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from ppafm.common import eVA_Nm |
| 7 | +from ppafm.ml.AuxMap import ESMapConstant |
| 8 | +from ppafm.ml.Generator import InverseAFMtrainer |
| 9 | +from ppafm.ocl.AFMulator import AFMulator |
| 10 | + |
| 11 | +from mlspm.data_generation import TarWriter |
| 12 | +from mlspm.datasets import download_dataset |
| 13 | + |
| 14 | + |
class Trainer(InverseAFMtrainer):
    """Data generator with tip-dependent lateral stiffness and randomized tip distance/tilt.

    Overrides two hooks of ``InverseAFMtrainer``: ``on_afm_start`` sets the scanner
    stiffness according to the current probe particle, and ``handle_distance`` adds
    randomization before the base-class distance handling runs.
    """

    def on_afm_start(self):
        """Set the scanner stiffness of ``self.afmulator`` based on its current ``iZPP``.

        Raises:
            RuntimeError: If ``self.afmulator.iZPP`` is not one of 8, 54 or 17.
        """
        # Always go through self.afmulator: relying on a module-level ``afmulator``
        # global only works when this file is run as a script and breaks on import.
        afmulator = self.afmulator

        # Use different lateral stiffness for Cl than CO and Xe
        if afmulator.iZPP in [8, 54]:
            k_lateral = 0.25
        elif afmulator.iZPP == 17:
            k_lateral = 0.5
        else:
            raise RuntimeError(f"Unknown tip {afmulator.iZPP}")

        # NOTE(review): dividing by -eVA_Nm presumably converts from N/m to the
        # simulator's internal units and sign convention - confirm against ppafm.
        afmulator.scanner.stiffness = np.array([k_lateral, k_lateral, 0.0, 30.0], dtype=np.float32) / -eVA_Nm

    # Override to randomize tip distance and probe tilt
    def handle_distance(self):
        """Randomize the tip distance and probe tilt, then run the base-class handling."""
        self.randomize_distance(delta=0.25)
        self.randomize_tip(max_tilt=0.5)
        super().handle_distance()
| 31 | + |
| 32 | + |
if __name__ == "__main__":
    # Script entry point: download the molecule geometries, simulate AFM images and
    # auxiliary maps for the train/val/test splits, and write the samples into
    # sharded tar files under ``data_dir``.

    # Path where molecule geometry files are saved
    mol_dir = Path("./molecules")

    # Directory where to save data
    data_dir = Path("./data/")

    # Define simulator and image descriptor parameters
    scan_window = ((0, 0, 6.0), (23.875, 23.875, 7.9))
    scan_dim = (192, 192, 19)
    afmulator = AFMulator(pixPerAngstrome=10, scan_dim=scan_dim, scan_window=scan_window, tipR0=[0.0, 0.0, 4.0])
    aux_maps = [
        ESMapConstant(
            scan_dim=afmulator.scan_dim[:2],
            scan_window=[afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]],
            height=4.0,
            vdW_cutoff=-2.0,
            Rpp=1.0,
        )
    ]
    generator_arguments = {
        "afmulator": afmulator,
        "aux_maps": aux_maps,
        "batch_size": 1,
        "distAbove": 5.25,
        "iZPPs": [8, 54, 17],  # CO, Xe, Cl
        "Qs": [[-10, 20, -10, 0], [30, -60, 30, 0], [-0.3, 0, 0, 0]],
        "QZs": [[0.1, 0, -0.1, 0], [0.1, 0, -0.1, 0], [0, 0, 0, 0]],
    }

    # Number of tar file shards for each set
    target_shard_count = 8

    # Make sure the save directory exists
    data_dir.mkdir(exist_ok=True, parents=True)

    # Download the dataset. The extraction may take a while since there are ~235k files.
    download_dataset("ED-AFM-molecules", mol_dir)

    # Paths to molecule xyz files
    train_paths = list((mol_dir / "train").glob("*.xyz"))
    val_paths = list((mol_dir / "validation").glob("*.xyz"))
    test_paths = list((mol_dir / "test").glob("*.xyz"))

    # Generate dataset
    start_time = time.perf_counter()
    for mode, molecules in zip(["train", "val", "test"], [train_paths, val_paths, test_paths]):

        # Construct generator
        generator = Trainer(paths=molecules, **generator_arguments)
        n_samples = len(generator)

        # Split the samples of this set evenly over the target number of shards
        max_count = ceil(n_samples / target_shard_count)

        # Generate data
        start_gen = time.perf_counter()
        with TarWriter(data_dir, f"{data_dir.name}-K-0_{mode}", max_count=max_count) as tar_writer:
            for i, (X, Y, xyz) in enumerate(generator):

                # Get rid of the batch dimension
                X = [x[0] for x in X]
                Y = [y[0] for y in Y]
                xyz = xyz[0]

                # Save information of the simulation parameters into the xyz comment line
                # NOTE(review): these are read from the afmulator after the sample was
                # produced; if several tips are simulated per sample, kxy presumably
                # reflects only the last tip - confirm against InverseAFMtrainer.
                amp = generator.afmulator.amplitude
                R0 = generator.afmulator.tipR0
                kxy = generator.afmulator.scanner.stiffness[0]
                sw = generator.afmulator.scan_window
                comment_str = f"Scan window: [{sw[0]}, {sw[1]}], Amplitude: {amp}, tip R0: {R0}, kxy: {kxy}"

                # Write the sample to a tar file
                tar_writer.add_sample(X, xyz, Y=Y, comment_str=comment_str)

                if i % 100 == 0:
                    # At this point i + 1 samples are finished, so that many are
                    # excluded from the remaining count used for the estimate.
                    done = i + 1
                    elapsed = time.perf_counter() - start_gen
                    eta = elapsed / done * (n_samples - done)
                    print(
                        f"{mode} sample {i}/{n_samples}, writing to `{tar_writer.ft.name}`, "
                        f"Elapsed: {elapsed:.2f}s, ETA: {eta:.2f}s"
                    )

        print(f"Done with {mode} - Elapsed time: {time.perf_counter() - start_gen:.1f}")

    print(f"Total time taken: {int(time.perf_counter() - start_time)}")
0 commit comments