Skip to content

Commit df43882

Browse files
authored
Relaxed ice structures and plotting scripts. (#37)
1 parent 3fe7020 commit df43882

File tree

8 files changed

+443
-6
lines changed

8 files changed

+443
-6
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies:
1111
- torchvision
1212
- torchaudio
1313
- cuda
14+
- ninja
1415
- matplotlib
1516
- numpy
1617
- scipy

mlspm/datasets.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"AFM-ice-Au111-monolayer": "https://zenodo.org/records/10049832/files/AFM-ice-Au111-monolayer.tar.gz?download=1",
1313
"AFM-ice-Au111-bilayer": "https://zenodo.org/records/10049856/files/AFM-ice-Au111-bilayer.tar.gz?download=1",
1414
"AFM-ice-exp": "https://zenodo.org/records/10054847/files/exp_data_ice.tar.gz?download=1",
15+
"AFM-ice-relaxed": "https://zenodo.org/records/10362511/files/relaxed_structures.tar.gz?download=1",
1516
}
1617

1718

@@ -29,6 +30,16 @@ def _safe_extract(tar, path=".", members=None, *, numeric_owner=False):
2930
raise Exception("Attempted Path Traversal in Tar File")
3031
tar.extractall(path, members, numeric_owner=numeric_owner)
3132

33+
def _common_parent(paths):
34+
path_parts = [list(Path(p).parts) for p in paths]
35+
common_part = Path()
36+
for parts in zip(*path_parts):
37+
p = parts[0]
38+
if all(part == p for part in parts):
39+
common_part /= p
40+
else:
41+
break
42+
return common_part
3243

3344
def download_dataset(name: str, target_dir: PathLike):
3445
"""
@@ -40,6 +51,7 @@ def download_dataset(name: str, target_dir: PathLike):
4051
- ``'AFM-ice-Au111-monolayer'``: https://doi.org/10.5281/zenodo.10049832
4152
- ``'AFM-ice-Au111-bilayer'``: https://doi.org/10.5281/zenodo.10049856
4253
- ``'AFM-ice-exp'``: https://doi.org/10.5281/zenodo.10054847
54+
- ``'AFM-ice-relaxed'``: https://doi.org/10.5281/zenodo.10362511
4355
4456
Arguments:
4557
name: Name of dataset to download.
@@ -64,7 +76,7 @@ def download_dataset(name: str, target_dir: PathLike):
6476
with tarfile.open(temp_file, "r") as ft:
6577
print("Reading archive files...")
6678
members = []
67-
base_dir = os.path.commonprefix(ft.getnames())
79+
base_dir = _common_parent(ft.getnames())
6880
for m in ft.getmembers():
6981
if m.isfile():
7082
# relative_to(base_dir) here gets rid of a common parent directory within the archive (if any),

papers/ice_structure_discovery/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ This folder contains the source code and links to the datasets that were used fo
1212

1313
The subdirectories contain various scripts for training and running predictions with the models:
1414
- `training`: Scripts for training the atom position and graph construction models, and evaluating the trained models.
15-
- `prediction`: Scripts for reproducing the result in Fig. 2 of the paper using the pretrained models.
15+
- `predictions`: Scripts for reproducing the results figures of the paper using the pretrained models.
1616

1717
## Data
1818

@@ -25,4 +25,6 @@ Training datasets:
2525

2626
Experimental data: https://doi.org/10.5281/zenodo.10054847
2727

28+
Final relaxed geometries: https://doi.org/10.5281/zenodo.10362511
29+
2830
Pretrained weights for the models: https://doi.org/10.5281/zenodo.10054348
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
The scripts here can be used to reproduce the results in Fig. 2 of the paper.
1+
The scripts here can be used to reproduce the results figures in the paper.
22
- `predict_experiments.py`: Runs the prediction for all of the experimental AFM images of ice on Cu(111) and Au(111) using the three models pretrained on the Cu(111), Au(111)-monolayer, and Au(111)-bilayer datasets, and saves them on disk.
3-
- `plot_predictions.py`: Picks the appropriate predictions for each experiment and plots them to a figure.
3+
- `plot_predictions.py`: Picks the appropriate predictions for each experiment and plots them to a figure as in Fig. 2 of the paper.
4+
- `plot_relaxed_structures.py`: Plots the on-surface structures relaxed with a neural network potential and DFT as well as the corresponding simulations and experimental images as in Fig. 3 of the paper.
5+
- `plot_prediction_extra.py`: Plots the prediction and the relaxed structure with corresponding simulations and experimental images for the one extra ice cluster not in the main results figure.
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#!/usr/bin/env python3
2+
3+
from pathlib import Path
4+
5+
import matplotlib.pyplot as plt
6+
from torch import scatter
7+
from ppafm.ocl.oclUtils import init_env
8+
9+
from plot_predictions import get_data as get_data_prediction, MM_TO_INCH
10+
from plot_predictions import plot_graph as plot_graph_prediction
11+
from plot_relaxed_structures import get_data as get_data_relaxed
12+
from plot_relaxed_structures import plot_graph as plot_graph_relaxed
13+
14+
# # Set matplotlib font rendering to use LaTex
15+
# plt.rcParams.update({"text.usetex": True, "font.family": "serif", "font.serif": ["Computer Modern Roman"]})
16+
17+
18+
def init_fig(width=140, left_margin=4, top_margin=4, row_gap=6, gap=0.5):
19+
ax_size = (width - left_margin - 5 * gap) / 5
20+
21+
left_margin *= MM_TO_INCH
22+
top_margin *= MM_TO_INCH
23+
row_gap *= MM_TO_INCH
24+
gap *= MM_TO_INCH
25+
ax_size *= MM_TO_INCH
26+
width *= MM_TO_INCH
27+
height = top_margin + 2 * (ax_size + gap) + row_gap
28+
fig = plt.figure(figsize=(width, height))
29+
30+
axes = []
31+
32+
y = height - top_margin - ax_size
33+
x = left_margin
34+
axes_ = []
35+
for _ in range(5):
36+
rect = [x / width, y / height, ax_size / width, ax_size / height]
37+
ax = fig.add_axes(rect)
38+
ax.set_xticks([])
39+
ax.set_yticks([])
40+
for axis in ["top", "bottom", "left", "right"]:
41+
ax.spines[axis].set_linewidth(0.5)
42+
axes_.append(ax)
43+
x += ax_size + gap
44+
axes.append(axes_)
45+
46+
y = height - top_margin - 2 * ax_size - row_gap
47+
x = left_margin + 2 * (ax_size + gap)
48+
axes_ = []
49+
for _ in range(3):
50+
rect = [x / width, y / height, ax_size / width, ax_size / height]
51+
ax = fig.add_axes(rect)
52+
ax.set_xticks([])
53+
ax.set_yticks([])
54+
for axis in ["top", "bottom", "left", "right"]:
55+
ax.spines[axis].set_linewidth(0.5)
56+
axes_.append(ax)
57+
x += ax_size + gap
58+
axes.append(axes_)
59+
60+
return fig, axes
61+
62+
63+
if __name__ == "__main__":
64+
init_env(i_platform=0)
65+
66+
exp_data_dir = Path("./exp_data")
67+
sim_data_dir = Path("./relaxed_structures/")
68+
scatter_size = 5
69+
zmin = -5.0
70+
zmax = 0.5
71+
classes = [[1], [8], [29, 79]]
72+
class_colors = ["w", "r"]
73+
fontsize = 7
74+
75+
params = {
76+
"pred_dir": "predictions_au111-bilayer",
77+
"sim_name": "hartree_I",
78+
"exp_name": "Ying_Jiang_4",
79+
"label": "I",
80+
"dist": 4.8,
81+
"rot_angle": -25.000,
82+
"amp": 2.0,
83+
"nz": 7,
84+
"offset": (0.0, 0.0),
85+
}
86+
87+
exp_data, pred_mol, sim_pred = get_data_prediction(params, exp_data_dir, classes)
88+
opt_mol, sim_opt, _, sw_opt = get_data_relaxed(params, exp_data_dir, sim_data_dir, classes)
89+
90+
fig, axes = init_fig()
91+
92+
# Plot data
93+
axes[0][0].imshow(exp_data['data'][:, :, 0].T, origin="lower", cmap="gray")
94+
axes[0][1].imshow(exp_data['data'][:, :, -1].T, origin="lower", cmap="gray")
95+
plot_graph_prediction(
96+
axes[0][2],
97+
pred_mol,
98+
box_borders=[[0, 0, zmin], [exp_data["lengthX"], exp_data["lengthY"], zmax]],
99+
zmin=zmin,
100+
zmax=zmax,
101+
scatter_size=scatter_size,
102+
class_colors=class_colors,
103+
)
104+
axes[0][3].imshow(sim_pred[:, :, 0].T, origin="lower", cmap="gray")
105+
axes[0][4].imshow(sim_pred[:, :, -1].T, origin="lower", cmap="gray")
106+
plot_graph_relaxed(
107+
axes[1][0],
108+
opt_mol,
109+
box_borders=[[sw_opt[0][0], sw_opt[0][1], zmin], [sw_opt[1][0], sw_opt[1][1], zmax]],
110+
zmin=zmin,
111+
zmax=zmax,
112+
scatter_size=scatter_size,
113+
class_colors=class_colors,
114+
)
115+
axes[1][1].imshow(sim_opt[:, :, 0].T, origin="lower", cmap="gray")
116+
axes[1][2].imshow(sim_opt[:, :, -1].T, origin="lower", cmap="gray")
117+
118+
# Set labels
119+
y = 1.08
120+
axes[0][0].text(
121+
-0.08, 0.5, params["label"], transform=axes[0][0].transAxes, fontsize=fontsize, va="center", ha="center", rotation="vertical"
122+
)
123+
axes[0][0].text(0.5, y, "Exp.\ AFM (far)", transform=axes[0][0].transAxes, fontsize=fontsize, va="center", ha="center")
124+
axes[0][1].text(0.5, y, "Exp.\ AFM (close)", transform=axes[0][1].transAxes, fontsize=fontsize, va="center", ha="center")
125+
axes[0][2].text(0.5, y, "Pred.\ geom.", transform=axes[0][2].transAxes, fontsize=fontsize, va="center", ha="center")
126+
axes[0][3].text(0.5, y, "Sim.\ AFM (far)", transform=axes[0][3].transAxes, fontsize=fontsize, va="center", ha="center")
127+
axes[0][4].text(0.5, y, "Sim.\ AFM (close)", transform=axes[0][4].transAxes, fontsize=fontsize, va="center", ha="center")
128+
axes[1][0].text(0.5, y, "Opt.\ geom.", transform=axes[1][0].transAxes, fontsize=fontsize, va="center", ha="center")
129+
axes[1][1].text(0.5, y, "Sim.\ AFM (far)", transform=axes[1][1].transAxes, fontsize=fontsize, va="center", ha="center")
130+
axes[1][2].text(0.5, y, "Sim.\ AFM (close)", transform=axes[1][2].transAxes, fontsize=fontsize, va="center", ha="center")
131+
132+
plt.savefig(f"sims_extra.png", dpi=400)

papers/ice_structure_discovery/predictions/plot_predictions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ def plot_graph(ax, mol, box_borders, class_colors, scatter_size, zmin, zmax):
199199
{'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_2_2', 'label': 'E', 'dist': 4.9, 'offset': ( 0.0, 0.0)},
200200
{'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_3' , 'label': 'F', 'dist': 4.8, 'offset': ( 0.0, -2.0)},
201201
{'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_5' , 'label': 'G', 'dist': 5.0, 'offset': ( 2.0, 0.0)},
202-
{'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_6' , 'label': 'H', 'dist': 4.8, 'offset': ( 1.5, 2.0)}
202+
{'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_6' , 'label': 'H', 'dist': 4.8, 'offset': ( 1.5, 2.0)},
203+
# {'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_4' , 'label': 'I', 'dist': 4.8, 'offset': ( 0.0, 0.0)}
203204
]
204205

205206
data = [get_data(p, exp_data_dir, classes) for p in params]

0 commit comments

Comments
 (0)