Skip to content

Commit 2d63ea7

Browse files
committed
Added ED-AFM figure scripts
1 parent 162231e commit 2d63ea7

23 files changed

+2672
-4
lines changed

Diff for: mlspm/datasets.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
"AFM-ice-relaxed": "https://zenodo.org/records/10362511/files/relaxed_structures.tar.gz?download=1",
1616
"ASD-AFM-molecules": "https://zenodo.org/records/10562769/files/molecules.tar.gz?download=1",
1717
"AFM-camphor-exp": "https://zenodo.org/records/10562769/files/afm_camphor.tar.gz?download=1",
18-
"ED-AFM-molecules": "https://zenodo.org/records/10606443/files/molecules_rebias.tar.gz?download=1",
18+
"ED-AFM-molecules": "https://zenodo.org/records/10609676/files/molecules_rebias.tar.gz?download=1",
19+
"ED-AFM-data": "https://zenodo.org/records/10609676/files/edafm-data.tar.gz?download=1",
1920
}
2021

2122

@@ -44,7 +45,8 @@ def download_dataset(name: str, target_dir: PathLike):
4445
- ``'AFM-ice-relaxed'``: https://doi.org/10.5281/zenodo.10362511
4546
- ``'ASD-AFM-molecules'``: https://doi.org/10.5281/zenodo.10562769 - 'molecules.tar.gz'
4647
- ``'AFM-camphor-exp'``: https://doi.org/10.5281/zenodo.10562769 - 'afm_camphor.tar.gz'
47-
- ``'ED-AFM-molecules'``: https://doi.org/10.5281/zenodo.10606443
48+
- ``'ED-AFM-molecules'``: https://doi.org/10.5281/zenodo.10609676 - 'molecules_rebias.tar.gz'
49+
- ``'ED-AFM-data'``: https://doi.org/10.5281/zenodo.10609676 - 'edafm-data.tar.gz'
4850
4951
Arguments:
5052
name: Name of the dataset to download.

Diff for: papers/asd-afm/generate_data.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def on_sample_start(self):
3131
# Define simulator and image descriptor parameters
3232
scan_window = ((0, 0, 6.0), (15.875, 15.875, 7.9))
3333
scan_dim = (128, 128, 19)
34-
afmulator = AFMulator(pixPerAngstrome=5, scan_dim=scan_dim, scan_window=scan_window)
34+
afmulator = AFMulator(pixPerAngstrome=5, scan_dim=scan_dim, scan_window=scan_window, tipR0=[0.0, 0.0, 4.0])
3535
aux_maps = [
3636
AtomicDisks(scan_dim=scan_dim, scan_window=scan_window, zmin=-1.2, zmax_s=-1.2, diskMode="sphere"),
3737
vdwSpheres(scan_dim=scan_dim, scan_window=scan_window, zmin=-1.5),
@@ -41,7 +41,7 @@ def on_sample_start(self):
4141
"afmulator": afmulator,
4242
"aux_maps": aux_maps,
4343
"batch_size": 1,
44-
"distAbove": 4.3,
44+
"distAbove": 5.25,
4545
"iZPPs": [8],
4646
"Qs": [[-0.1, 0, 0, 0]],
4747
"QZs": [[0, 0, 0, 0]],

Diff for: papers/ed-afm/figures/afm_stacks.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from pathlib import Path
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
6+
from mlspm.datasets import download_dataset
7+
8+
# # Set matplotlib font rendering to use LaTex
9+
# plt.rcParams.update({
10+
# "text.usetex": True,
11+
# "font.family": "serif",
12+
# "font.serif": ["Computer Modern Roman"]
13+
# })
14+
15+
if __name__ == "__main__":
16+
data_dir = Path("./edafm-data")
17+
fig_width = 160
18+
fontsize = 8
19+
dpi = 300
20+
21+
# Download data if not already there
22+
download_dataset("ED-AFM-data", data_dir)
23+
24+
# Load data
25+
bcb_CO = np.load(data_dir / "BCB" / "data_CO_exp.npz")
26+
bcb_Xe = np.load(data_dir / "BCB" / "data_Xe_exp.npz")
27+
ptcda_CO = np.load(data_dir / "PTCDA" / "data_CO_exp.npz")
28+
ptcda_Xe = np.load(data_dir / "PTCDA" / "data_Xe_exp.npz")
29+
30+
fig_width = 0.1 / 2.54 * fig_width
31+
height_ratios = [2, 2, 2.45, 2.45]
32+
fig = plt.figure(figsize=(fig_width, 0.85 * sum(height_ratios)))
33+
fig_grid = fig.add_gridspec(4, 1, wspace=0, hspace=0.1, height_ratios=height_ratios)
34+
35+
# BCB plots
36+
for i, (sample, label) in enumerate(zip([bcb_CO, bcb_Xe], ["A", "B"])):
37+
d = sample["data"]
38+
l = sample["lengthX"]
39+
axes = fig_grid[i, 0].subgridspec(2, 8, wspace=0.02, hspace=0.02).subplots().flatten()
40+
for j, ax in enumerate(axes):
41+
if j < d.shape[-1]:
42+
ax.imshow(d[:, :, j].T, origin="lower", cmap="afmhot")
43+
ax.axis("off")
44+
axes[0].text(
45+
-0.3, 0.8, label, horizontalalignment="center", verticalalignment="center", transform=axes[0].transAxes, fontsize=fontsize
46+
)
47+
axes[0].plot([50, 50 + 5 / l * d.shape[0]], [470, 470], color="k")
48+
49+
# PTCDA plots
50+
for i, (sample, label) in enumerate(zip([ptcda_CO, ptcda_Xe], ["C", "D"])):
51+
d = sample["data"]
52+
l = sample["lengthX"]
53+
axes = fig_grid[i + 2, 0].subgridspec(3, 6, wspace=0.02, hspace=0.02).subplots().flatten()
54+
for j, ax in enumerate(axes):
55+
if j < d.shape[-1]:
56+
ax.imshow(d[:, :, j].T, origin="lower", cmap="afmhot")
57+
ax.axis("off")
58+
axes[0].text(
59+
-0.22, 0.7, label, horizontalalignment="center", verticalalignment="center", transform=axes[0].transAxes, fontsize=fontsize
60+
)
61+
axes[0].plot([20, 20 + 5 / l * d.shape[0]], [135, 135], color="k")
62+
63+
plt.savefig("afm_stacks.pdf", bbox_inches="tight", dpi=dpi)

Diff for: papers/ed-afm/figures/afm_stacks2.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
from pathlib import Path
3+
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
7+
from mlspm.datasets import download_dataset
8+
9+
# # Set matplotlib font rendering to use LaTex
10+
# plt.rcParams.update({
11+
# "text.usetex": True,
12+
# "font.family": "serif",
13+
# "font.serif": ["Computer Modern Roman"]
14+
# })
15+
16+
if __name__ == "__main__":
17+
data_dir = Path("./edafm-data")
18+
fig_width = 160
19+
fontsize = 8
20+
dpi = 300
21+
22+
# Download data if not already there
23+
download_dataset("ED-AFM-data", data_dir)
24+
25+
# Load data
26+
water_CO = np.load(data_dir / 'Water' / 'data_CO_exp.npz')
27+
water_Xe = np.load(data_dir / 'Water' / 'data_Xe_exp.npz')
28+
29+
fig = plt.figure(figsize=(0.1/2.54*fig_width, 5.0))
30+
fig_grid = fig.add_gridspec(2, 1, wspace=0, hspace=0.1)
31+
32+
# Water plots
33+
for i, (sample, label) in enumerate(zip([water_CO, water_Xe], ['E', 'F'])):
34+
d = sample['data']
35+
l = sample['lengthX']
36+
axes = fig_grid[i, 0].subgridspec(3, 8, wspace=0.02, hspace=0.02).subplots().flatten()
37+
for j, ax in enumerate(axes):
38+
if j < d.shape[-1]:
39+
ax.imshow(d[:,:,j].T, origin='lower', cmap='afmhot')
40+
ax.axis('off')
41+
axes[0].text(-0.3, 0.8, label, horizontalalignment='center',
42+
verticalalignment='center', transform=axes[0].transAxes, fontsize=fontsize)
43+
axes[0].plot([50, 50+5/l*d.shape[0]], [470, 470], color='k')
44+
45+
plt.savefig('afm_stacks2.pdf', bbox_inches='tight', dpi=dpi)

Diff for: papers/ed-afm/figures/background_gradient.py

+232
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
2+
from pathlib import Path
3+
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
import ppafm.ml.AuxMap as aux
7+
import ppafm.ocl.field as FFcl
8+
import ppafm.ocl.oclUtils as oclu
9+
import ppafm.ocl.relax as oclr
10+
import torch
11+
from matplotlib import cm
12+
from ppafm.ml.Generator import InverseAFMtrainer
13+
from ppafm.ocl.AFMulator import AFMulator
14+
15+
import mlspm.preprocessing as pp
16+
from mlspm.models import EDAFMNet
17+
18+
# # Set matplotlib font rendering to use LaTex
19+
# plt.rcParams.update({
20+
# "text.usetex": True,
21+
# "font.family": "serif",
22+
# "font.serif": ["Computer Modern Roman"]
23+
# })
24+
25+
def apply_preprocessing_sim(batch):
26+
27+
X, Y, xyzs = batch
28+
29+
print(X[0].shape)
30+
31+
X = [x[..., 2:8] for x in X]
32+
33+
pp.add_norm(X)
34+
np.random.seed(0)
35+
pp.add_noise(X, c=0.08)
36+
37+
# Add background gradient
38+
c = 0.3
39+
angle = -np.pi / 2
40+
x, y = np.meshgrid(np.arange(0, X[0].shape[1]), np.arange(0, X[0].shape[2]), indexing="ij")
41+
n = [np.cos(angle), np.sin(angle), 1]
42+
z = -(n[0]*x + n[1]*y)
43+
z -= z.mean()
44+
z /= np.ptp(z)
45+
for x in X:
46+
x += z[None, :, :, None]*c*np.ptp(x)
47+
48+
return X, Y, xyzs
49+
50+
def apply_preprocessing_exp(X, real_dim):
51+
52+
# Pick slices
53+
x0_start, x1_start = 2, 0
54+
X[0] = X[0][..., x0_start:x0_start+6] # CO
55+
X[1] = X[1][..., x1_start:x1_start+6] # Xe
56+
57+
X = pp.interpolate_and_crop(X, real_dim)
58+
pp.add_norm(X)
59+
X = [x[:,:,6:78] for x in X]
60+
61+
return X
62+
63+
if __name__ == "__main__":
64+
65+
data_dir = Path("./edafm-data") # Path to data
66+
X_slices = [0, 3, 5] # Which AFM slices to plot
67+
tip_names = ["CO", "Xe"] # AFM tip types
68+
device = "cuda" # Device to run inference on
69+
fig_width = 140 # Figure width in mm
70+
fontsize = 8
71+
dpi = 300
72+
73+
# Initialize OpenCL environment on GPU
74+
env = oclu.OCLEnvironment( i_platform = 0 )
75+
FFcl.init(env)
76+
oclr.init(env)
77+
78+
afmulator_args = {
79+
"pixPerAngstrome" : 20,
80+
"scan_dim" : (176, 144, 19),
81+
"scan_window" : ((2.0, 2.0, 7.0), (24, 20, 8.9)),
82+
"df_steps" : 10,
83+
"tipR0" : [0.0, 0.0, 4.0]
84+
}
85+
86+
generator_kwargs = {
87+
"batch_size" : 1,
88+
"distAbove" : 5.25,
89+
"iZPPs" : [8, 54],
90+
"Qs" : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]],
91+
"QZs" : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]]
92+
}
93+
94+
# Paths to molecule xyz files
95+
molecules = [data_dir / "PTCDA" / "mol.xyz"]
96+
97+
# Define AFMulator
98+
afmulator = AFMulator(**afmulator_args)
99+
afmulator.npbc = (0,0,0)
100+
101+
# Define AuxMaps
102+
aux_maps = [
103+
aux.ESMapConstant(
104+
scan_dim = afmulator.scan_dim[:2],
105+
scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]],
106+
height = 4.0,
107+
vdW_cutoff = -2.0,
108+
Rpp = 1.0
109+
)
110+
]
111+
112+
# Define generator
113+
trainer = InverseAFMtrainer(afmulator, aux_maps, molecules, **generator_kwargs)
114+
115+
# Get simulation data
116+
sim_data = next(iter(trainer))
117+
X_sim, ref, xyzs = apply_preprocessing_sim(sim_data)
118+
X_sim_cuda = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X_sim]
119+
120+
# Load experimental data and preprocess
121+
data1 = np.load(data_dir / "PTCDA" / "data_CO_exp.npz")
122+
X1 = data1["data"]
123+
afm_dim1 = (data1["lengthX"], data1["lengthY"])
124+
125+
data2 = np.load(data_dir / "PTCDA" / "data_Xe_exp.npz")
126+
X2 = data2["data"]
127+
afm_dim2 = (data2["lengthX"], data2["lengthY"])
128+
129+
assert afm_dim1 == afm_dim2
130+
afm_dim = afm_dim1
131+
X_exp = apply_preprocessing_exp([X1[None], X2[None]], afm_dim)
132+
X_exp_cuda = [torch.from_numpy(x.astype(np.float32)).unsqueeze(1).to(device) for x in X_exp]
133+
134+
# Load model with gradient augmentation
135+
model_grad = EDAFMNet(device=device, pretrained_weights="base")
136+
137+
# Load model without gradient augmentation
138+
model_no_grad = EDAFMNet(device=device, pretrained_weights="no-gradient")
139+
140+
with torch.no_grad():
141+
pred_sim_grad, attentions_sim_grad = model_grad(X_sim_cuda)
142+
pred_sim_no_grad, attentions_sim_no_grad = model_no_grad(X_sim_cuda)
143+
pred_exp, attentions_exp = model_no_grad(X_exp_cuda)
144+
pred_sim_grad = [p.cpu().numpy() for p in pred_sim_grad]
145+
pred_sim_no_grad = [p.cpu().numpy() for p in pred_sim_no_grad]
146+
pred_exp = [p.cpu().numpy() for p in pred_exp]
147+
attentions_sim_grad = [a.cpu().numpy() for a in attentions_sim_grad]
148+
attentions_sim_no_grad = [a.cpu().numpy() for a in attentions_sim_no_grad]
149+
attentions_exp = [a.cpu().numpy() for a in attentions_exp]
150+
151+
# Create figure grid
152+
fig_width = 0.1/2.54*fig_width
153+
width_ratios = [6, 4.4]
154+
fig = plt.figure(figsize=(fig_width, 6*fig_width/sum(width_ratios)))
155+
fig_grid = fig.add_gridspec(1, 2, wspace=0.3, hspace=0, width_ratios=width_ratios)
156+
left_grid = fig_grid[0, 0].subgridspec(2, 1, wspace=0, hspace=0.1)
157+
158+
pred_sim_grid = fig_grid[0, 1].subgridspec(2, 1, wspace=0, hspace=0.1)
159+
pred_sim_no_grad_ax, cbar_sim_no_grad_ax = pred_sim_grid[0, 0].subgridspec(1, 2, wspace=0.05,
160+
hspace=0, width_ratios=[1, 0.08]).subplots()
161+
pred_sim_grad_ax, cbar_sim_grad_ax = pred_sim_grid[1, 0].subgridspec(1, 2, wspace=0.05,
162+
hspace=0, width_ratios=[1, 0.08]).subplots()
163+
pred_exp_ax, cbar_exp_ax = left_grid[0, 0].subgridspec(1, 2, wspace=0.05, width_ratios=[1, 0.05]).subplots()
164+
afm_axes = left_grid[1, 0].subgridspec(len(X_sim), len(X_slices), wspace=0.01, hspace=0.01).subplots(squeeze=False)
165+
166+
# Plot AFM
167+
for i, x in enumerate(X_sim):
168+
for j, s in enumerate(X_slices):
169+
170+
# Plot AFM slice
171+
im = afm_axes[i, j].imshow(x[0,:,:,s].T, origin="lower", cmap="afmhot")
172+
afm_axes[i, j].set_axis_off()
173+
174+
# Put tip names to the left of the AFM image rows
175+
afm_axes[i, 0].text(-0.1, 0.5, tip_names[i], horizontalalignment="center",
176+
verticalalignment="center", transform=afm_axes[i, 0].transAxes,
177+
rotation="vertical", fontsize=fontsize)
178+
179+
# Figure out ES data limits
180+
vmax_sim_no_grad = max(abs(pred_sim_no_grad[0].min()), abs(pred_sim_no_grad[0].max()))
181+
vmax_sim_grad = max(abs(pred_sim_grad[0].min()), abs(pred_sim_grad[0].max()))
182+
vmax_exp = max(abs(pred_exp[0].min()), abs(pred_exp[0].max()))
183+
vmin_sim_no_grad = -vmax_sim_no_grad
184+
vmin_sim_grad = -vmax_sim_grad
185+
vmin_exp = -vmax_exp
186+
187+
# Plot ES predictions
188+
pred_sim_no_grad_ax.imshow(pred_sim_no_grad[0][0].T, origin="lower", cmap="coolwarm",
189+
vmin=vmin_sim_no_grad, vmax=vmax_sim_no_grad)
190+
pred_sim_grad_ax.imshow(pred_sim_grad[0][0].T, origin="lower", cmap="coolwarm",
191+
vmin=vmin_sim_grad, vmax=vmax_sim_grad)
192+
pred_exp_ax.imshow(pred_exp[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin_exp, vmax=vmax_exp)
193+
194+
pred_sim_no_grad_ax.set_axis_off()
195+
pred_sim_grad_ax.set_axis_off()
196+
pred_exp_ax.set_axis_off()
197+
198+
# Plot ES Map colorbar for no grad prediction
199+
m_es = cm.ScalarMappable(cmap=cm.coolwarm)
200+
m_es.set_array((vmin_sim_no_grad, vmax_sim_no_grad))
201+
cbar = plt.colorbar(m_es, cax=cbar_sim_no_grad_ax)
202+
cbar.set_ticks([-0.1, 0.0, 0.1])
203+
cbar_sim_no_grad_ax.tick_params(labelsize=fontsize-1)
204+
cbar.set_label("V/Å", fontsize=fontsize)
205+
206+
# Plot ES Map colorbar for grad prediction
207+
m_es = cm.ScalarMappable(cmap=cm.coolwarm)
208+
m_es.set_array((vmin_sim_grad, vmax_sim_grad))
209+
cbar = plt.colorbar(m_es, cax=cbar_sim_grad_ax)
210+
cbar.set_ticks([-0.1, 0.0, 0.1])
211+
cbar_sim_grad_ax.tick_params(labelsize=fontsize-1)
212+
cbar.set_label("V/Å", fontsize=fontsize)
213+
214+
# Plot ES Map colorbar for experimental prediction
215+
m_es = cm.ScalarMappable(cmap=cm.coolwarm)
216+
m_es.set_array((vmin_exp, vmax_exp))
217+
cbar = plt.colorbar(m_es, cax=cbar_exp_ax)
218+
cbar.set_ticks([-0.04, 0.0, 0.04])
219+
cbar_exp_ax.tick_params(labelsize=fontsize-1)
220+
cbar.set_label("V/Å", fontsize=fontsize)
221+
222+
# Set labels
223+
pred_exp_ax.text(-0.06, 0.98, "A", horizontalalignment="center",
224+
verticalalignment="center", transform=pred_exp_ax.transAxes, fontsize=fontsize)
225+
afm_axes[0, 0].text(-0.2, 1.0, "B", horizontalalignment="center",
226+
verticalalignment="center", transform=afm_axes[0, 0].transAxes, fontsize=fontsize)
227+
pred_sim_no_grad_ax.text(-0.08, 0.98, "C", horizontalalignment="center",
228+
verticalalignment="center", transform=pred_sim_no_grad_ax.transAxes, fontsize=fontsize)
229+
pred_sim_grad_ax.text(-0.08, 0.98, "D", horizontalalignment="center",
230+
verticalalignment="center", transform=pred_sim_grad_ax.transAxes, fontsize=fontsize)
231+
232+
plt.savefig("background_gradient.pdf", bbox_inches="tight", dpi=dpi)

0 commit comments

Comments
 (0)