Skip to content

Commit c977849

Browse files
authored
Paper code: ASD-AFM (#39)
1 parent 0eba5b4 commit c977849

29 files changed

+1203
-136
lines changed

.gitignore

+4-80
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,6 @@ share/python-wheels/
2727
*.egg
2828
MANIFEST
2929

30-
# PyInstaller
31-
# Usually these files are written by a python script from a template
32-
# before PyInstaller builds the exe, so as to inject date/other infos into it.
33-
*.manifest
34-
*.spec
35-
3630
# Installer logs
3731
pip-log.txt
3832
pip-delete-this-directory.txt
@@ -52,26 +46,6 @@ coverage.xml
5246
.pytest_cache/
5347
cover/
5448

55-
# Translations
56-
*.mo
57-
*.pot
58-
59-
# Django stuff:
60-
*.log
61-
local_settings.py
62-
db.sqlite3
63-
db.sqlite3-journal
64-
65-
# Flask stuff:
66-
instance/
67-
.webassets-cache
68-
69-
# Scrapy stuff:
70-
.scrapy
71-
72-
# Sphinx documentation
73-
docs/_build/
74-
7549
# PyBuilder
7650
.pybuilder/
7751
target/
@@ -83,43 +57,9 @@ target/
8357
profile_default/
8458
ipython_config.py
8559

86-
# pyenv
87-
# For a library or package, you might want to ignore these files since the code is
88-
# intended to run in multiple environments; otherwise, check them in:
89-
# .python-version
90-
91-
# pipenv
92-
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93-
# However, in case of collaboration, if having platform-specific dependencies or dependencies
94-
# having no cross-platform support, pipenv may install dependencies that don't work, or not
95-
# install all needed dependencies.
96-
#Pipfile.lock
97-
98-
# poetry
99-
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100-
# This is especially recommended for binary packages to ensure reproducibility, and is more
101-
# commonly ignored for libraries.
102-
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103-
#poetry.lock
104-
105-
# pdm
106-
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107-
#pdm.lock
108-
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109-
# in version control.
110-
# https://pdm.fming.dev/#use-with-ide
111-
.pdm.toml
112-
11360
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
11461
__pypackages__/
11562

116-
# Celery stuff
117-
celerybeat-schedule
118-
celerybeat.pid
119-
120-
# SageMath parsed files
121-
*.sage.py
122-
12363
# Environments
12464
.env
12565
.venv
@@ -133,29 +73,13 @@ venv.bak/
13373
.spyderproject
13474
.spyproject
13575

136-
# Rope project settings
137-
.ropeproject
138-
139-
# mkdocs documentation
140-
/site
141-
14276
# mypy
14377
.mypy_cache/
14478
.dmypy.json
14579
dmypy.json
14680

147-
# Pyre type checker
148-
.pyre/
149-
150-
# pytype static type analyzer
151-
.pytype/
152-
153-
# Cython debug symbols
154-
cython_debug/
81+
# VS Code
82+
.vscode
15583

156-
# PyCharm
157-
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158-
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159-
# and can be added to the global gitignore or merged into this file. For a more nuclear
160-
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
161-
#.idea/
84+
# Other
85+
molecules

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@ pip install .
2929

3030
## Papers
3131
The [`papers`](papers) subdirectory contains training scripts and datasets for specific publications. Currently we have the following:
32+
- [Automated structure discovery in atomic force microscopy](papers/asd-afm)
3233
- [Structure discovery in Atomic Force Microscopy imaging of ice](papers/ice_structure_discovery)

docs/source/reference/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Reference
77
mlspm.data_loading
88
mlspm.datasets
99
mlspm.graph
10+
mlspm.image
1011
mlspm.logging
1112
mlspm.losses
1213
mlspm.models

docs/source/reference/mlspm.image.rst

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
mlspm.image
2+
===========
3+
4+
.. automodule:: mlspm.image
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:
8+
9+
mlspm.image.models
10+
------------------
11+
12+
.. automodule:: mlspm.image.models
13+
:members:
14+
:undoc-members:
15+
:show-inheritance:

docs/source/reference/mlspm.models.rst

+4
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,8 @@ mlspm.models
1313

1414
Alias of :class:`mlspm.graph.models.GraphImgNetIce`
1515

16+
.. class:: mlspm.models.ASDAFMNet
17+
18+
Alias of :class:`mlspm.image.models.ASDAFMNet`
19+
1620
.. autofunction:: mlspm.models.download_weights

docs/source/reference/mlspm.visualization.rst

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ mlspm.visualization
66
:undoc-members:
77
:show-inheritance:
88

9+
.. autofunction:: mlspm.visualization.make_prediction_plots
10+
911
.. autofunction:: mlspm.visualization.plot_distribution_grid
1012

1113
.. autofunction:: mlspm.visualization.plot_graphs

mlspm/_weights.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
"graph-ice-cu111": "https://zenodo.org/records/10054348/files/weights_ice-cu111.pth?download=1",
1010
"graph-ice-au111-monolayer": "https://zenodo.org/records/10054348/files/weights_ice-au111-monolayer.pth?download=1",
1111
"graph-ice-au111-bilayer": "https://zenodo.org/records/10054348/files/weights_ice-au111-bilayer.pth?download=1",
12+
"asdafm-light": "https://zenodo.org/records/10514470/files/weights_asdafm_light.pth?download=1",
13+
"asdafm-heavy": "https://zenodo.org/records/10514470/files/weights_asdafm_heavy.pth?download=1",
1214
}
1315

1416

@@ -18,18 +20,20 @@ def download_weights(weights_name: str, target_path: Optional[PathLike] = None)
1820
1921
The following weights are available:
2022
21-
- ``'graph-ice-cu111'``: PosNet trained on ice clusters on Cu(111).
22-
- ``'graph-ice-au111-monolayer'``: PosNet trained on monolayer ice clusters on Au(111).
23-
- ``'graph-ice-au111-bilayer'``: PosNet trained on bilayer ice clusters on Au(111).
23+
- ``'graph-ice-cu111'``: PosNet trained on ice clusters on Cu(111). (https://doi.org/10.5281/zenodo.10054348)
24+
- ``'graph-ice-au111-monolayer'``: PosNet trained on monolayer ice clusters on Au(111). (https://doi.org/10.5281/zenodo.10054348)
25+
- ``'graph-ice-au111-bilayer'``: PosNet trained on bilayer ice clusters on Au(111). (https://doi.org/10.5281/zenodo.10054348)
26+
- ``'asdafm-light'``: ASDAFMNet trained on molecules containing the elements H, C, N, O, and F. (https://doi.org/10.5281/zenodo.10514470)
27+
- ``'asdafm-heavy'``: ASDAFMNet trained on molecules additionally containing Si, P, S, Cl, and Br. (https://doi.org/10.5281/zenodo.10514470)
28+
2429
2530
Arguments:
2631
weights_name: Name of weights to download.
2732
target_path: Path where the weights file will be saved. If specified, the parent directory for the file has to exists.
28-
If not specified, a location in cache directory is chosen. If the target file already exists, the download is skipped
33+
If not specified, a location in a cache directory is chosen. If the target file already exists, the download is skipped
2934
3035
Returns:
3136
Path where the weights were saved.
32-
3337
"""
3438
try:
3539
weights_url = WEIGHTS_URLS[weights_name]

mlspm/data_generation.py

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
2+
import io
3+
import os
4+
import tarfile
5+
import time
6+
from os import PathLike
7+
from pathlib import Path
8+
from typing import List, Optional
9+
10+
import numpy as np
11+
from PIL import Image
12+
13+
14+
class TarWriter:
15+
'''
16+
Write samples of AFM images, molecules and descriptors to tar files. Use as a context manager and add samples with
17+
:meth:`add_sample`.
18+
19+
Each tar file has a maximum number of samples, and whenever that maximum is reached, a new tar file is created.
20+
The generated tar files are named as ``{base_name}_{n}.tar`` and saved into the specified folder. The current tar file
21+
handle is always available in the attribute :attr:`ft`, and is automatically closed when the context ends.
22+
23+
Arguments:
24+
base_path: Path to directory where tar files are saved.
25+
base_name: Base name for output tar files. The number of the tar file is appended to the name.
26+
max_count: Maximum number of samples per tar file.
27+
png_compress_level: Compression level 1-9 for saved png images. Larger value for smaller file size but slower
28+
write speed.
29+
'''
30+
31+
def __init__(self, base_path: PathLike='./', base_name: str='', max_count: int=100, png_compress_level=4):
32+
self.base_path = Path(base_path)
33+
self.base_name = base_name
34+
self.max_count = max_count
35+
self.png_compress_level = png_compress_level
36+
37+
def __enter__(self):
38+
self.sample_count = 0
39+
self.total_count = 0
40+
self.tar_count = 0
41+
self.ft = self._get_tar_file()
42+
return self
43+
44+
def __exit__(self, exc_type, exc_value, exc_traceback):
45+
self.ft.close()
46+
47+
def _get_tar_file(self):
48+
file_path = self.base_path / f'{self.base_name}_{self.tar_count}.tar'
49+
if os.path.exists(file_path):
50+
raise RuntimeError(f'Tar file already exists at `{file_path}`')
51+
return tarfile.open(file_path, 'w', format=tarfile.GNU_FORMAT)
52+
53+
def add_sample(self, X: List[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray]=None, comment_str: str=''):
54+
"""
55+
Add a sample to the current tar file.
56+
57+
Arguments:
58+
X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz).
59+
xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element].
60+
Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny).
61+
comment_str: Comment line (second line) to add to the xyz file.
62+
"""
63+
64+
if self.sample_count >= self.max_count:
65+
self.tar_count += 1
66+
self.sample_count = 0
67+
self.ft.close()
68+
self.ft = self._get_tar_file()
69+
70+
# Write AFM images
71+
for i, x in enumerate(X):
72+
for j in range(x.shape[-1]):
73+
xj = x[:, :, j]
74+
xj = ((xj - xj.min()) / np.ptp(xj) * (2**8 - 1)).astype(np.uint8) # Convert range to 0-255 integers
75+
img_bytes = io.BytesIO()
76+
Image.fromarray(xj.T[::-1], mode='L').save(img_bytes, 'png', compress_level=self.png_compress_level)
77+
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
78+
self.ft.addfile(get_tarinfo(f'{self.total_count}.{j:02d}.{i}.png', img_bytes), img_bytes)
79+
img_bytes.close()
80+
81+
# Write xyz file
82+
xyz_bytes = io.BytesIO()
83+
xyz_bytes.write(bytearray(f'{len(xyzs)}\n{comment_str}\n', 'utf-8'))
84+
for xyz in xyzs:
85+
xyz_bytes.write(bytearray(f'{int(xyz[-1])}\t', 'utf-8'))
86+
for i in range(len(xyz)-1):
87+
xyz_bytes.write(bytearray(f'{xyz[i]:10.8f}\t', 'utf-8'))
88+
xyz_bytes.write(bytearray('\n', 'utf-8'))
89+
xyz_bytes.seek(0) # Return stream to start so that addfile can read it correctly
90+
self.ft.addfile(get_tarinfo(f'{self.total_count}.xyz', xyz_bytes), xyz_bytes)
91+
xyz_bytes.close()
92+
93+
# Write image descriptors (if any)
94+
if Y is not None:
95+
for i, y in enumerate(Y):
96+
img_bytes = io.BytesIO()
97+
np.save(img_bytes, y.astype(np.float32))
98+
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
99+
self.ft.addfile(get_tarinfo(f'{self.total_count}.desc_{i}.npy', img_bytes), img_bytes)
100+
img_bytes.close()
101+
102+
self.sample_count += 1
103+
self.total_count += 1
104+
105+
def get_tarinfo(fname: str, file_bytes: io.BytesIO):
106+
info = tarfile.TarInfo(fname)
107+
info.size = file_bytes.getbuffer().nbytes
108+
info.mtime = time.time()
109+
return info

0 commit comments

Comments
 (0)