Skip to content

Commit 1da4976

Browse files
committed
Added ED-AFM data generation script.
1 parent 4fc47c3 commit 1da4976

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

Diff for: papers/ed-afm/generate_data.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import time
2+
from math import ceil
3+
from pathlib import Path
4+
5+
import numpy as np
6+
from ppafm.common import eVA_Nm
7+
from ppafm.ml.AuxMap import ESMapConstant
8+
from ppafm.ml.Generator import InverseAFMtrainer
9+
from ppafm.ocl.AFMulator import AFMulator
10+
11+
from mlspm.data_generation import TarWriter
12+
from mlspm.datasets import download_dataset
13+
14+
15+
class Trainer(InverseAFMtrainer):
    """Sample generator for the ED-AFM dataset.

    Extends ``InverseAFMtrainer`` with a tip-dependent scanner stiffness and
    with randomization of the tip-sample distance and probe tilt.
    """

    def on_afm_start(self):
        """Set the scanner stiffness to match the currently selected tip.

        The tip is identified by ``self.afmulator.iZPP``: 8 (CO) and 54 (Xe)
        share one lateral stiffness, 17 (Cl) uses a larger one.

        Raises:
            RuntimeError: If ``self.afmulator.iZPP`` is not 8, 54, or 17.
        """
        # Use different lateral stiffness for Cl than CO and Xe
        tip = self.afmulator.iZPP
        if tip in (8, 54):
            lateral = 0.25
        elif tip == 17:
            lateral = 0.5
        else:
            raise RuntimeError(f"Unknown tip {tip}")

        # Write to the simulator owned by this generator rather than to a
        # module-level ``afmulator`` name, so the class also works when it is
        # imported or constructed with a different simulator instance.
        # NOTE(review): dividing by -eVA_Nm presumably converts from N/m to the
        # simulator's internal units and sign convention -- confirm against ppafm.
        self.afmulator.scanner.stiffness = np.array([lateral, lateral, 0.0, 30.0], dtype=np.float32) / -eVA_Nm

    # Override to randomize tip distance and probe tilt
    def handle_distance(self):
        """Randomize the tip distance and tilt, then run the base-class distance handling."""
        self.randomize_distance(delta=0.25)
        self.randomize_tip(max_tilt=0.5)
        super().handle_distance()
31+
32+
33+
if __name__ == "__main__":
    # Generate the ED-AFM train/validation/test datasets: download the molecule
    # geometries, simulate AFM images for every molecule and tip, and write the
    # samples into sharded tar files under `data_dir`.

    # Path where molecule geometry files are saved
    mol_dir = Path("./molecules")

    # Directory where to save data
    data_dir = Path("./data/")

    # Define simulator and image descriptor parameters
    scan_window = ((0, 0, 6.0), (23.875, 23.875, 7.9))
    scan_dim = (192, 192, 19)
    afmulator = AFMulator(pixPerAngstrome=10, scan_dim=scan_dim, scan_window=scan_window, tipR0=[0.0, 0.0, 4.0])
    aux_maps = [
        ESMapConstant(
            scan_dim=afmulator.scan_dim[:2],
            scan_window=[afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]],
            height=4.0,
            vdW_cutoff=-2.0,
            Rpp=1.0,
        )
    ]
    generator_arguments = {
        "afmulator": afmulator,
        "aux_maps": aux_maps,
        "batch_size": 1,
        "distAbove": 5.25,
        "iZPPs": [8, 54, 17],  # CO, Xe, Cl
        "Qs": [[-10, 20, -10, 0], [30, -60, 30, 0], [-0.3, 0, 0, 0]],
        "QZs": [[0.1, 0, -0.1, 0], [0.1, 0, -0.1, 0], [0, 0, 0, 0]],
    }

    # Number of tar file shards for each set
    target_shard_count = 8

    # Make sure the save directory exists
    data_dir.mkdir(exist_ok=True, parents=True)

    # Download the dataset. The extraction may take a while since there are ~235k files.
    download_dataset("ED-AFM-molecules", mol_dir)

    # Paths to molecule xyz files
    train_paths = list((mol_dir / "train").glob("*.xyz"))
    val_paths = list((mol_dir / "validation").glob("*.xyz"))
    test_paths = list((mol_dir / "test").glob("*.xyz"))

    # Generate dataset
    start_time = time.perf_counter()
    for mode, molecules in zip(["train", "val", "test"], [train_paths, val_paths, test_paths]):

        # Construct generator
        generator = Trainer(paths=molecules, **generator_arguments)
        n_samples = len(generator)

        # Generate data, splitting each set into roughly `target_shard_count` tar files
        max_count = ceil(n_samples / target_shard_count)
        start_gen = time.perf_counter()
        with TarWriter(data_dir, f"{data_dir.name}-K-0_{mode}", max_count=max_count) as tar_writer:
            for i, (X, Y, xyz) in enumerate(generator):

                # Get rid of the batch dimension (batch_size is 1)
                X = [x[0] for x in X]
                Y = [y[0] for y in Y]
                xyz = xyz[0]

                # Save information of the simulation parameters into the xyz comment line
                amp = generator.afmulator.amplitude
                R0 = generator.afmulator.tipR0
                kxy = generator.afmulator.scanner.stiffness[0]
                sw = generator.afmulator.scan_window
                comment_str = f"Scan window: [{sw[0]}, {sw[1]}], Amplitude: {amp}, tip R0: {R0}, kxy: {kxy}"

                # Write the sample to a tar file
                tar_writer.add_sample(X, xyz, Y=Y, comment_str=comment_str)

                if i % 100 == 0:
                    elapsed = time.perf_counter() - start_gen
                    # At this point i + 1 samples are finished, so n_samples - (i + 1) remain
                    eta = elapsed / (i + 1) * (n_samples - i - 1)
                    print(
                        f"{mode} sample {i}/{n_samples}, writing to `{tar_writer.ft.name}`, "
                        f"Elapsed: {elapsed:.2f}s, ETA: {eta:.2f}s"
                    )

        print(f"Done with {mode} - Elapsed time: {time.perf_counter() - start_gen:.1f}")

    print("Total time taken: %d" % (time.perf_counter() - start_time))

0 commit comments

Comments
 (0)