1
+
2
+ from pathlib import Path
3
+
4
+ import matplotlib .pyplot as plt
5
+ import numpy as np
6
+ import ppafm .ml .AuxMap as aux
7
+ import ppafm .ocl .field as FFcl
8
+ import ppafm .ocl .oclUtils as oclu
9
+ import ppafm .ocl .relax as oclr
10
+ import torch
11
+ from matplotlib import cm
12
+ from ppafm .ml .Generator import InverseAFMtrainer
13
+ from ppafm .ocl .AFMulator import AFMulator
14
+
15
+ import mlspm .preprocessing as pp
16
+ from mlspm .models import EDAFMNet
17
+
18
+ # # Set matplotlib font rendering to use LaTex
19
+ # plt.rcParams.update({
20
+ # "text.usetex": True,
21
+ # "font.family": "serif",
22
+ # "font.serif": ["Computer Modern Roman"]
23
+ # })
24
+
25
+ def apply_preprocessing_sim (batch ):
26
+
27
+ X , Y , xyzs = batch
28
+
29
+ print (X [0 ].shape )
30
+
31
+ X = [x [..., 2 :8 ] for x in X ]
32
+
33
+ pp .add_norm (X )
34
+ np .random .seed (0 )
35
+ pp .add_noise (X , c = 0.08 )
36
+
37
+ # Add background gradient
38
+ c = 0.3
39
+ angle = - np .pi / 2
40
+ x , y = np .meshgrid (np .arange (0 , X [0 ].shape [1 ]), np .arange (0 , X [0 ].shape [2 ]), indexing = "ij" )
41
+ n = [np .cos (angle ), np .sin (angle ), 1 ]
42
+ z = - (n [0 ]* x + n [1 ]* y )
43
+ z -= z .mean ()
44
+ z /= np .ptp (z )
45
+ for x in X :
46
+ x += z [None , :, :, None ]* c * np .ptp (x )
47
+
48
+ return X , Y , xyzs
49
+
50
+ def apply_preprocessing_exp (X , real_dim ):
51
+
52
+ # Pick slices
53
+ x0_start , x1_start = 2 , 0
54
+ X [0 ] = X [0 ][..., x0_start :x0_start + 6 ] # CO
55
+ X [1 ] = X [1 ][..., x1_start :x1_start + 6 ] # Xe
56
+
57
+ X = pp .interpolate_and_crop (X , real_dim )
58
+ pp .add_norm (X )
59
+ X = [x [:,:,6 :78 ] for x in X ]
60
+
61
+ return X
62
+
63
+ if __name__ == "__main__" :
64
+
65
+ data_dir = Path ("./edafm-data" ) # Path to data
66
+ X_slices = [0 , 3 , 5 ] # Which AFM slices to plot
67
+ tip_names = ["CO" , "Xe" ] # AFM tip types
68
+ device = "cuda" # Device to run inference on
69
+ fig_width = 140 # Figure width in mm
70
+ fontsize = 8
71
+ dpi = 300
72
+
73
+ # Initialize OpenCL environment on GPU
74
+ env = oclu .OCLEnvironment ( i_platform = 0 )
75
+ FFcl .init (env )
76
+ oclr .init (env )
77
+
78
+ afmulator_args = {
79
+ "pixPerAngstrome" : 20 ,
80
+ "scan_dim" : (176 , 144 , 19 ),
81
+ "scan_window" : ((2.0 , 2.0 , 7.0 ), (24 , 20 , 8.9 )),
82
+ "df_steps" : 10 ,
83
+ "tipR0" : [0.0 , 0.0 , 4.0 ]
84
+ }
85
+
86
+ generator_kwargs = {
87
+ "batch_size" : 1 ,
88
+ "distAbove" : 5.25 ,
89
+ "iZPPs" : [8 , 54 ],
90
+ "Qs" : [[ - 10 , 20 , - 10 , 0 ], [ 30 , - 60 , 30 , 0 ]],
91
+ "QZs" : [[ 0.1 , 0 , - 0.1 , 0 ], [ 0.1 , 0 , - 0.1 , 0 ]]
92
+ }
93
+
94
+ # Paths to molecule xyz files
95
+ molecules = [data_dir / "PTCDA" / "mol.xyz" ]
96
+
97
+ # Define AFMulator
98
+ afmulator = AFMulator (** afmulator_args )
99
+ afmulator .npbc = (0 ,0 ,0 )
100
+
101
+ # Define AuxMaps
102
+ aux_maps = [
103
+ aux .ESMapConstant (
104
+ scan_dim = afmulator .scan_dim [:2 ],
105
+ scan_window = [afmulator .scan_window [0 ][:2 ], afmulator .scan_window [1 ][:2 ]],
106
+ height = 4.0 ,
107
+ vdW_cutoff = - 2.0 ,
108
+ Rpp = 1.0
109
+ )
110
+ ]
111
+
112
+ # Define generator
113
+ trainer = InverseAFMtrainer (afmulator , aux_maps , molecules , ** generator_kwargs )
114
+
115
+ # Get simulation data
116
+ sim_data = next (iter (trainer ))
117
+ X_sim , ref , xyzs = apply_preprocessing_sim (sim_data )
118
+ X_sim_cuda = [torch .from_numpy (x ).unsqueeze (1 ).to (device ) for x in X_sim ]
119
+
120
+ # Load experimental data and preprocess
121
+ data1 = np .load (data_dir / "PTCDA" / "data_CO_exp.npz" )
122
+ X1 = data1 ["data" ]
123
+ afm_dim1 = (data1 ["lengthX" ], data1 ["lengthY" ])
124
+
125
+ data2 = np .load (data_dir / "PTCDA" / "data_Xe_exp.npz" )
126
+ X2 = data2 ["data" ]
127
+ afm_dim2 = (data2 ["lengthX" ], data2 ["lengthY" ])
128
+
129
+ assert afm_dim1 == afm_dim2
130
+ afm_dim = afm_dim1
131
+ X_exp = apply_preprocessing_exp ([X1 [None ], X2 [None ]], afm_dim )
132
+ X_exp_cuda = [torch .from_numpy (x .astype (np .float32 )).unsqueeze (1 ).to (device ) for x in X_exp ]
133
+
134
+ # Load model with gradient augmentation
135
+ model_grad = EDAFMNet (device = device , pretrained_weights = "base" )
136
+
137
+ # Load model without gradient augmentation
138
+ model_no_grad = EDAFMNet (device = device , pretrained_weights = "no-gradient" )
139
+
140
+ with torch .no_grad ():
141
+ pred_sim_grad , attentions_sim_grad = model_grad (X_sim_cuda )
142
+ pred_sim_no_grad , attentions_sim_no_grad = model_no_grad (X_sim_cuda )
143
+ pred_exp , attentions_exp = model_no_grad (X_exp_cuda )
144
+ pred_sim_grad = [p .cpu ().numpy () for p in pred_sim_grad ]
145
+ pred_sim_no_grad = [p .cpu ().numpy () for p in pred_sim_no_grad ]
146
+ pred_exp = [p .cpu ().numpy () for p in pred_exp ]
147
+ attentions_sim_grad = [a .cpu ().numpy () for a in attentions_sim_grad ]
148
+ attentions_sim_no_grad = [a .cpu ().numpy () for a in attentions_sim_no_grad ]
149
+ attentions_exp = [a .cpu ().numpy () for a in attentions_exp ]
150
+
151
+ # Create figure grid
152
+ fig_width = 0.1 / 2.54 * fig_width
153
+ width_ratios = [6 , 4.4 ]
154
+ fig = plt .figure (figsize = (fig_width , 6 * fig_width / sum (width_ratios )))
155
+ fig_grid = fig .add_gridspec (1 , 2 , wspace = 0.3 , hspace = 0 , width_ratios = width_ratios )
156
+ left_grid = fig_grid [0 , 0 ].subgridspec (2 , 1 , wspace = 0 , hspace = 0.1 )
157
+
158
+ pred_sim_grid = fig_grid [0 , 1 ].subgridspec (2 , 1 , wspace = 0 , hspace = 0.1 )
159
+ pred_sim_no_grad_ax , cbar_sim_no_grad_ax = pred_sim_grid [0 , 0 ].subgridspec (1 , 2 , wspace = 0.05 ,
160
+ hspace = 0 , width_ratios = [1 , 0.08 ]).subplots ()
161
+ pred_sim_grad_ax , cbar_sim_grad_ax = pred_sim_grid [1 , 0 ].subgridspec (1 , 2 , wspace = 0.05 ,
162
+ hspace = 0 , width_ratios = [1 , 0.08 ]).subplots ()
163
+ pred_exp_ax , cbar_exp_ax = left_grid [0 , 0 ].subgridspec (1 , 2 , wspace = 0.05 , width_ratios = [1 , 0.05 ]).subplots ()
164
+ afm_axes = left_grid [1 , 0 ].subgridspec (len (X_sim ), len (X_slices ), wspace = 0.01 , hspace = 0.01 ).subplots (squeeze = False )
165
+
166
+ # Plot AFM
167
+ for i , x in enumerate (X_sim ):
168
+ for j , s in enumerate (X_slices ):
169
+
170
+ # Plot AFM slice
171
+ im = afm_axes [i , j ].imshow (x [0 ,:,:,s ].T , origin = "lower" , cmap = "afmhot" )
172
+ afm_axes [i , j ].set_axis_off ()
173
+
174
+ # Put tip names to the left of the AFM image rows
175
+ afm_axes [i , 0 ].text (- 0.1 , 0.5 , tip_names [i ], horizontalalignment = "center" ,
176
+ verticalalignment = "center" , transform = afm_axes [i , 0 ].transAxes ,
177
+ rotation = "vertical" , fontsize = fontsize )
178
+
179
+ # Figure out ES data limits
180
+ vmax_sim_no_grad = max (abs (pred_sim_no_grad [0 ].min ()), abs (pred_sim_no_grad [0 ].max ()))
181
+ vmax_sim_grad = max (abs (pred_sim_grad [0 ].min ()), abs (pred_sim_grad [0 ].max ()))
182
+ vmax_exp = max (abs (pred_exp [0 ].min ()), abs (pred_exp [0 ].max ()))
183
+ vmin_sim_no_grad = - vmax_sim_no_grad
184
+ vmin_sim_grad = - vmax_sim_grad
185
+ vmin_exp = - vmax_exp
186
+
187
+ # Plot ES predictions
188
+ pred_sim_no_grad_ax .imshow (pred_sim_no_grad [0 ][0 ].T , origin = "lower" , cmap = "coolwarm" ,
189
+ vmin = vmin_sim_no_grad , vmax = vmax_sim_no_grad )
190
+ pred_sim_grad_ax .imshow (pred_sim_grad [0 ][0 ].T , origin = "lower" , cmap = "coolwarm" ,
191
+ vmin = vmin_sim_grad , vmax = vmax_sim_grad )
192
+ pred_exp_ax .imshow (pred_exp [0 ][0 ].T , origin = "lower" , cmap = "coolwarm" , vmin = vmin_exp , vmax = vmax_exp )
193
+
194
+ pred_sim_no_grad_ax .set_axis_off ()
195
+ pred_sim_grad_ax .set_axis_off ()
196
+ pred_exp_ax .set_axis_off ()
197
+
198
+ # Plot ES Map colorbar for no grad prediction
199
+ m_es = cm .ScalarMappable (cmap = cm .coolwarm )
200
+ m_es .set_array ((vmin_sim_no_grad , vmax_sim_no_grad ))
201
+ cbar = plt .colorbar (m_es , cax = cbar_sim_no_grad_ax )
202
+ cbar .set_ticks ([- 0.1 , 0.0 , 0.1 ])
203
+ cbar_sim_no_grad_ax .tick_params (labelsize = fontsize - 1 )
204
+ cbar .set_label ("V/Å" , fontsize = fontsize )
205
+
206
+ # Plot ES Map colorbar for grad prediction
207
+ m_es = cm .ScalarMappable (cmap = cm .coolwarm )
208
+ m_es .set_array ((vmin_sim_grad , vmax_sim_grad ))
209
+ cbar = plt .colorbar (m_es , cax = cbar_sim_grad_ax )
210
+ cbar .set_ticks ([- 0.1 , 0.0 , 0.1 ])
211
+ cbar_sim_grad_ax .tick_params (labelsize = fontsize - 1 )
212
+ cbar .set_label ("V/Å" , fontsize = fontsize )
213
+
214
+ # Plot ES Map colorbar for experimental prediction
215
+ m_es = cm .ScalarMappable (cmap = cm .coolwarm )
216
+ m_es .set_array ((vmin_exp , vmax_exp ))
217
+ cbar = plt .colorbar (m_es , cax = cbar_exp_ax )
218
+ cbar .set_ticks ([- 0.04 , 0.0 , 0.04 ])
219
+ cbar_exp_ax .tick_params (labelsize = fontsize - 1 )
220
+ cbar .set_label ("V/Å" , fontsize = fontsize )
221
+
222
+ # Set labels
223
+ pred_exp_ax .text (- 0.06 , 0.98 , "A" , horizontalalignment = "center" ,
224
+ verticalalignment = "center" , transform = pred_exp_ax .transAxes , fontsize = fontsize )
225
+ afm_axes [0 , 0 ].text (- 0.2 , 1.0 , "B" , horizontalalignment = "center" ,
226
+ verticalalignment = "center" , transform = afm_axes [0 , 0 ].transAxes , fontsize = fontsize )
227
+ pred_sim_no_grad_ax .text (- 0.08 , 0.98 , "C" , horizontalalignment = "center" ,
228
+ verticalalignment = "center" , transform = pred_sim_no_grad_ax .transAxes , fontsize = fontsize )
229
+ pred_sim_grad_ax .text (- 0.08 , 0.98 , "D" , horizontalalignment = "center" ,
230
+ verticalalignment = "center" , transform = pred_sim_grad_ax .transAxes , fontsize = fontsize )
231
+
232
+ plt .savefig ("background_gradient.pdf" , bbox_inches = "tight" , dpi = dpi )
0 commit comments