Skip to content

Commit 18a406b

Browse files
committed
Added generator for loading files from tar files.
1 parent 1807794 commit 18a406b

File tree

3 files changed

+221
-26
lines changed

3 files changed

+221
-26
lines changed

Diff for: docs/source/reference/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Reference
44
.. toctree::
55

66
mlspm.cli
7+
mlspm.data_generation
78
mlspm.data_loading
89
mlspm.datasets
910
mlspm.graph

Diff for: docs/source/reference/mlspm.data_generation.rst

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
mlspm.data_generation
2+
=====================
3+
4+
.. automodule:: mlspm.data_generation
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:
8+

Diff for: mlspm/data_generation.py

+212-26
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
21
import io
2+
import multiprocessing as mp
3+
import multiprocessing.shared_memory
34
import os
45
import tarfile
56
import time
67
from os import PathLike
78
from pathlib import Path
8-
from typing import List, Optional
9+
from typing import Optional, TypedDict, NotRequired
910

1011
import numpy as np
1112
from PIL import Image
1213

13-
1414
class TarWriter:
15-
'''
15+
"""
1616
Write samples of AFM images, molecules and descriptors to tar files. Use as a context manager and add samples with
1717
:meth:`add_sample`.
1818
@@ -25,10 +25,10 @@ class TarWriter:
2525
base_name: Base name for output tar files. The number of the tar file is appended to the name.
2626
max_count: Maximum number of samples per tar file.
2727
png_compress_level: Compression level 1-9 for saved png images. Larger value for smaller file size but slower
28-
write speed.
29-
'''
28+
write speed.
29+
"""
3030

31-
def __init__(self, base_path: PathLike='./', base_name: str='', max_count: int=100, png_compress_level=4):
31+
def __init__(self, base_path: PathLike = "./", base_name: str = "", max_count: int = 100, png_compress_level=4):
3232
self.base_path = Path(base_path)
3333
self.base_name = base_name
3434
self.max_count = max_count
@@ -45,12 +45,12 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
4545
self.ft.close()
4646

4747
def _get_tar_file(self):
48-
file_path = self.base_path / f'{self.base_name}_{self.tar_count}.tar'
48+
file_path = self.base_path / f"{self.base_name}_{self.tar_count}.tar"
4949
if os.path.exists(file_path):
50-
raise RuntimeError(f'Tar file already exists at `{file_path}`')
51-
return tarfile.open(file_path, 'w', format=tarfile.GNU_FORMAT)
50+
raise RuntimeError(f"Tar file already exists at `{file_path}`")
51+
return tarfile.open(file_path, "w", format=tarfile.GNU_FORMAT)
5252

53-
def add_sample(self, X: List[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray]=None, comment_str: str=''):
53+
def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray] = None, comment_str: str = ""):
5454
"""
5555
Add a sample to the current tar file.
5656
@@ -71,39 +71,225 @@ def add_sample(self, X: List[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarr
7171
for i, x in enumerate(X):
7272
for j in range(x.shape[-1]):
7373
xj = x[:, :, j]
74-
xj = ((xj - xj.min()) / np.ptp(xj) * (2**8 - 1)).astype(np.uint8) # Convert range to 0-255 integers
74+
xj = ((xj - xj.min()) / np.ptp(xj) * (2**8 - 1)).astype(np.uint8) # Convert range to 0-255 integers
7575
img_bytes = io.BytesIO()
76-
Image.fromarray(xj.T[::-1], mode='L').save(img_bytes, 'png', compress_level=self.png_compress_level)
77-
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
78-
self.ft.addfile(get_tarinfo(f'{self.total_count}.{j:02d}.{i}.png', img_bytes), img_bytes)
76+
Image.fromarray(xj.T[::-1], mode="L").save(img_bytes, "png", compress_level=self.png_compress_level)
77+
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
78+
self.ft.addfile(get_tarinfo(f"{self.total_count}.{j:02d}.{i}.png", img_bytes), img_bytes)
7979
img_bytes.close()
80-
80+
8181
# Write xyz file
8282
xyz_bytes = io.BytesIO()
83-
xyz_bytes.write(bytearray(f'{len(xyzs)}\n{comment_str}\n', 'utf-8'))
83+
xyz_bytes.write(bytearray(f"{len(xyzs)}\n{comment_str}\n", "utf-8"))
8484
for xyz in xyzs:
85-
xyz_bytes.write(bytearray(f'{int(xyz[-1])}\t', 'utf-8'))
86-
for i in range(len(xyz)-1):
87-
xyz_bytes.write(bytearray(f'{xyz[i]:10.8f}\t', 'utf-8'))
88-
xyz_bytes.write(bytearray('\n', 'utf-8'))
89-
xyz_bytes.seek(0) # Return stream to start so that addfile can read it correctly
90-
self.ft.addfile(get_tarinfo(f'{self.total_count}.xyz', xyz_bytes), xyz_bytes)
85+
xyz_bytes.write(bytearray(f"{int(xyz[-1])}\t", "utf-8"))
86+
for i in range(len(xyz) - 1):
87+
xyz_bytes.write(bytearray(f"{xyz[i]:10.8f}\t", "utf-8"))
88+
xyz_bytes.write(bytearray("\n", "utf-8"))
89+
xyz_bytes.seek(0) # Return stream to start so that addfile can read it correctly
90+
self.ft.addfile(get_tarinfo(f"{self.total_count}.xyz", xyz_bytes), xyz_bytes)
9191
xyz_bytes.close()
9292

9393
# Write image descriptors (if any)
9494
if Y is not None:
9595
for i, y in enumerate(Y):
9696
img_bytes = io.BytesIO()
9797
np.save(img_bytes, y.astype(np.float32))
98-
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
99-
self.ft.addfile(get_tarinfo(f'{self.total_count}.desc_{i}.npy', img_bytes), img_bytes)
98+
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
99+
self.ft.addfile(get_tarinfo(f"{self.total_count}.desc_{i}.npy", img_bytes), img_bytes)
100100
img_bytes.close()
101101

102102
self.sample_count += 1
103103
self.total_count += 1
104104

105+
105106
def get_tarinfo(fname: str, file_bytes: io.BytesIO) -> tarfile.TarInfo:
    """
    Build the tar header for an in-memory file.

    Arguments:
        fname: Member name the file gets inside the tar archive.
        file_bytes: Buffer holding the file contents. Only its total size is read here; the stream
            position is not touched.

    Returns:
        ``tarfile.TarInfo`` with the name, the size in bytes of the whole buffer, and the current time
        as the modification time.
    """
    info = tarfile.TarInfo(fname)
    # A live getbuffer() view prevents the BytesIO from being resized or closed, so release it
    # explicitly instead of relying on garbage collection.
    with file_bytes.getbuffer() as view:
        info.size = view.nbytes
    info.mtime = time.time()
    return info
111+
112+
class TarSample(TypedDict):
113+
"""
114+
- ``'hartree'``: Path to the Hartree potential. First item in the tuple is the path
115+
to the tar file relative to ``base_path``, and second entry is the tar file member name.
116+
- ``'rho'``: (Optional) Path to the electron density. First item in the tuple is the path
117+
to the tar file relative to ``base_path``, and second entry is the tar file member name.
118+
- ``'rots'``: List of rotations to generate for the sample.
119+
"""
120+
hartree: tuple[str, str]
121+
rho: NotRequired[tuple[str, str]]
122+
rots: list[np.ndarray]
123+
124+
class TarDataGenerator:
    """
    Iterable that loads data from tar archives with data saved in npz format for generating samples
    with the GeneratorAFMTrainer in ppafm.

    The npz files should contain the following entries:

        - ``'data'``: An array containing the potential/density on a 3D grid.
        - ``'origin'``: Lattice origin in 3D space as an array of shape ``(3,)``.
        - ``'lattice'``: Lattice vectors as an array of shape ``(3, 3)``, where the rows are the vectors.
        - ``'xyz'``: Atom xyz coordinates as an array of shape ``(n_atoms, 3)``.
        - ``'Z'``: Atom atomic numbers as an array of shape ``(n_atoms,)``.

    Iterating yields one dict per rotation of every sample, with the keys ``'xyzs'``, ``'Zs'``, ``'qs'``
    (the Hartree potential), ``'rho_sample'`` (the electron density, or ``None``) and ``'rot'``.

    Arguments:
        samples: List of sample dicts as :class:`TarSample`. If ``rho`` is present in the dict, the full-density-based model
            is used in the simulation. Otherwise Lennard-Jones with Hartree electrostatics is used.
        base_path: Path to the directory with the tar files.
        n_proc: Number of parallel processes for loading data. The samples get divided evenly over the processes.
    """

    _timings = False  # When True, the workers and the main process print per-sample timing breakdowns

    def __init__(self, samples: "list[TarSample]", base_path: PathLike = "./", n_proc: int = 1):
        self.samples = samples
        self.base_path = Path(base_path)
        self.n_proc = n_proc

    def __len__(self):
        """Total number of samples (including rotations)"""
        return sum(len(s["rots"]) for s in self.samples)

    def _launch_procs(self):
        """Create the result queue and start one loader process (with its own event) per split of the samples."""
        self.q = mp.Queue(maxsize=self.n_proc)
        self.events = []
        samples_split = np.array_split(self.samples, self.n_proc)
        for i in range(self.n_proc):
            # The main process sets this event when it is done with the worker's current sample
            event = mp.Event()
            p = mp.Process(target=self._load_samples, args=(samples_split[i], i, event))
            p.start()
            self.events.append(event)

    def __iter__(self):
        """Start the loader processes and return ``self`` as the iterator."""
        self._launch_procs()
        self.iterator = iter(self._yield_samples())
        return self

    def __next__(self):
        return next(self.iterator)

    def _get_data(self, tar_path: PathLike, name: str):
        """
        Load one npz member from a tar file.

        Arguments:
            tar_path: Path to the tar file relative to ``base_path``.
            name: Name of the npz member inside the tar file.

        Returns:
            Tuple ``(array, lvec, xyzs, Zs)``, where ``lvec`` is the origin stacked on top of the lattice vectors.
        """
        tar_path = self.base_path / tar_path
        # Close the npz handle as well as the tar file once the entries have been read into memory
        with tarfile.open(tar_path, "r") as f, np.load(f.extractfile(name)) as data:
            array = data["data"]
            origin = data["origin"]
            lattice = data["lattice"]
            xyzs = data["xyz"]
            Zs = data["Z"]
        lvec = np.concatenate([origin[None, :], lattice], axis=0)
        return array, lvec, xyzs, Zs

    def _load_samples(self, samples: "list[TarSample]", i_proc: int, event: mp.Event):
        """
        Worker loop: load each sample, publish the arrays through shared memory, and wait for the main
        process to finish with them before moving on to the next sample.

        Arguments:
            samples: The samples assigned to this worker.
            i_proc: Index of this worker.
            event: Event that the main process sets when it is done with the current sample.

        Raises:
            RuntimeError: If the main process does not set the event within 60 seconds.
        """

        # Time-derived id that goes into the shared memory block names of this worker
        proc_id = str(time.time_ns() + 1000000000 * i_proc)[-10:]
        print(f"Starting worker {i_proc}, id {proc_id}")

        for i, sample in enumerate(samples):

            if self._timings:
                t0 = time.perf_counter()

            # Load data from tar(s)
            rots = sample["rots"]
            hartree_tar_path, name = sample["hartree"]
            pot, lvec, xyzs, Zs = self._get_data(hartree_tar_path, name)
            pot *= -1  # Unit conversion, eV -> V
            if "rho" in sample:
                rho_tar_path, name = sample["rho"]
                rho, _, _, _ = self._get_data(rho_tar_path, name)

            if self._timings:
                t1 = time.perf_counter()

            # Put the data to shared memory. The blocks are sized from the loaded arrays and written
            # through float32 views, which is also the dtype that the main process reads them with.
            sample_id_pot = f"{i_proc}_{proc_id}_{i}_pot"
            shm_pot = mp.shared_memory.SharedMemory(create=True, size=pot.nbytes, name=sample_id_pot)
            shm_rho = None
            try:
                shm_view = np.ndarray(pot.shape, dtype=np.float32, buffer=shm_pot.buf)
                shm_view[:] = pot[:]
                del shm_view

                if "rho" in sample:
                    sample_id_rho = f"{i_proc}_{proc_id}_{i}__rho"
                    shm_rho = mp.shared_memory.SharedMemory(create=True, size=rho.nbytes, name=sample_id_rho)
                    shm_view = np.ndarray(rho.shape, dtype=np.float32, buffer=shm_rho.buf)
                    shm_view[:] = rho[:]
                    del shm_view
                    rho_shape = rho.shape
                else:
                    sample_id_rho = None
                    rho_shape = None

                if self._timings:
                    t2 = time.perf_counter()

                # Inform the main process of the data using the queue
                self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec, xyzs, Zs, rots))

                if self._timings:
                    t3 = time.perf_counter()

                # Wait until main process is done with the data
                if not event.wait(timeout=60):
                    raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.")
                event.clear()

                if self._timings:
                    t4 = time.perf_counter()

            finally:
                # Done with shared memory. Release it also when an error occurred above, so that the
                # blocks are not left behind.
                shm_pot.unlink()
                shm_pot.close()
                if shm_rho is not None:
                    shm_rho.unlink()
                    shm_rho.close()

            if self._timings:
                t5 = time.perf_counter()
                print(
                    f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Queue / Wait / Unlink: "
                    f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} / {t4 - t3:.5f} / {t5 - t4:.5f}"
                )

    def _yield_samples(self):
        """
        Generator run in the main process: receive each loaded sample from the workers and yield one
        dict per rotation of that sample.
        """

        from ppafm.ocl.field import ElectronDensity, HartreePotential

        # The workers put exactly one queue item per sample (not per rotation), so this loop runs once
        # per sample. Counting rotations here would leave the loop waiting on an empty queue after the
        # last sample.
        for _ in range(len(self.samples)):

            if self._timings:
                t0 = time.perf_counter()

            # Get data info from the queue
            i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec, xyzs, Zs, rots = self.q.get(timeout=200)

            # Get data from the shared memory
            shm_pot = mp.shared_memory.SharedMemory(sample_id_pot)
            pot = np.ndarray(pot_shape, dtype=np.float32, buffer=shm_pot.buf)
            pot = HartreePotential(pot, lvec)
            if sample_id_rho is not None:
                shm_rho = mp.shared_memory.SharedMemory(sample_id_rho)
                rho = np.ndarray(rho_shape, dtype=np.float32, buffer=shm_rho.buf)
                rho = ElectronDensity(rho, lvec)
            else:
                rho = None

            if self._timings:
                t1 = time.perf_counter()

            for rot in rots:
                sample_dict = {"xyzs": xyzs, "Zs": Zs, "qs": pot, "rho_sample": rho, "rot": rot}
                yield sample_dict

            if self._timings:
                t2 = time.perf_counter()

            # Close shared memory and inform producer that the shared memory can be unlinked
            shm_pot.close()
            if sample_id_rho is not None:
                shm_rho.close()
            self.events[i_proc].set()

            if self._timings:
                t3 = time.perf_counter()
                print(f"[Main, id {sample_id_pot}] Receive data / Yield / Event: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f}")

0 commit comments

Comments
 (0)