Skip to content

Commit c639611

Browse files
committed
Option for asynchronous writes in TarWriter.
1 parent b0b5b2a commit c639611

File tree

3 files changed

+70
-36
lines changed

3 files changed

+70
-36
lines changed

mlspm/data_generation.py

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import io
22
import multiprocessing as mp
33
import os
4+
import queue
45
import tarfile
56
import time
67
from multiprocessing.shared_memory import SharedMemory
78
from os import PathLike
89
from pathlib import Path
910
from typing import Optional, TypedDict
11+
import warnings
1012

1113
import numpy as np
1214
from PIL import Image
@@ -26,58 +28,81 @@ class TarWriter:
2628
base_path: Path to directory where tar files are saved.
2729
base_name: Base name for output tar files. The number of the tar file is appended to the name.
2830
max_count: Maximum number of samples per tar file.
29-
png_compress_level: Compression level 1-9 for saved png images. Larger value for smaller file size but slower
30-
write speed.
31+
async_write: Write tar files asynchronously in a parallel process.
3132
"""
3233

33-
def __init__(self, base_path: PathLike = "./", base_name: str = "", max_count: int = 100, png_compress_level=4):
34+
def __init__(self, base_path: PathLike = "./", base_name: str = "", max_count: int = 100, async_write=True):
3435
self.base_path = Path(base_path)
3536
self.base_name = base_name
3637
self.max_count = max_count
37-
self.png_compress_level = png_compress_level
38+
self.async_write = async_write
3839

3940
def __enter__(self):
4041
self.sample_count = 0
4142
self.total_count = 0
4243
self.tar_count = 0
43-
self.ft = self._get_tar_file()
44+
if self.async_write:
45+
self._launch_write_process()
46+
else:
47+
self._ft = self._get_tar_file()
4448
return self
4549

4650
def __exit__(self, exc_type, exc_value, exc_traceback):
47-
self.ft.close()
51+
if self.async_write:
52+
self._event_done.set()
53+
if not self._event_tar_close.wait(60):
54+
warnings.warn("Write process did not respond within timeout period. Last tar file may not have been closed properly.")
55+
else:
56+
self._ft.close()
57+
58+
def _launch_write_process(self):
59+
self._q = mp.Queue(1)
60+
self._event_done = mp.Event()
61+
self._event_tar_close = mp.Event()
62+
p = mp.Process(target=self._write_async)
63+
p.start()
64+
65+
def _write_async(self):
66+
self._ft = self._get_tar_file()
67+
try:
68+
while True:
69+
try:
70+
sample = self._q.get(block=False)
71+
self._add_sample(*sample)
72+
continue
73+
except queue.Empty:
74+
pass
75+
if self._event_done.is_set() and self._q.empty():
76+
self._ft.close()
77+
self._event_tar_close.set()
78+
return
79+
except:
80+
self._ft.close()
81+
self._event_tar_close.set()
4882

4983
def _get_tar_file(self):
5084
file_path = self.base_path / f"{self.base_name}_{self.tar_count}.tar"
5185
if os.path.exists(file_path):
5286
raise RuntimeError(f"Tar file already exists at `{file_path}`")
5387
return tarfile.open(file_path, "w", format=tarfile.GNU_FORMAT)
5488

55-
def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray] = None, comment_str: str = ""):
56-
"""
57-
Add a sample to the current tar file.
58-
59-
Arguments:
60-
X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz).
61-
xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element].
62-
Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny).
63-
comment_str: Comment line (second line) to add to the xyz file.
64-
"""
89+
def _add_sample(self, X, xyzs, Y, comment_str):
6590

6691
if self.sample_count >= self.max_count:
6792
self.tar_count += 1
6893
self.sample_count = 0
69-
self.ft.close()
70-
self.ft = self._get_tar_file()
94+
self._ft.close()
95+
self._ft = self._get_tar_file()
7196

7297
# Write AFM images
7398
for i, x in enumerate(X):
7499
for j in range(x.shape[-1]):
75100
xj = x[:, :, j]
76101
xj = ((xj - xj.min()) / np.ptp(xj) * (2**8 - 1)).astype(np.uint8) # Convert range to 0-255 integers
77102
img_bytes = io.BytesIO()
78-
Image.fromarray(xj.T[::-1], mode="L").save(img_bytes, "png", compress_level=self.png_compress_level)
103+
Image.fromarray(xj.T[::-1], mode="L").save(img_bytes, "png")
79104
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
80-
self.ft.addfile(get_tarinfo(f"{self.total_count}.{j:02d}.{i}.png", img_bytes), img_bytes)
105+
self._ft.addfile(get_tarinfo(f"{self.total_count}.{j:02d}.{i}.png", img_bytes), img_bytes)
81106
img_bytes.close()
82107

83108
# Write xyz file
@@ -89,7 +114,7 @@ def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarr
89114
xyz_bytes.write(bytearray(f"{xyz[i]:10.8f}\t", "utf-8"))
90115
xyz_bytes.write(bytearray("\n", "utf-8"))
91116
xyz_bytes.seek(0) # Return stream to start so that addfile can read it correctly
92-
self.ft.addfile(get_tarinfo(f"{self.total_count}.xyz", xyz_bytes), xyz_bytes)
117+
self._ft.addfile(get_tarinfo(f"{self.total_count}.xyz", xyz_bytes), xyz_bytes)
93118
xyz_bytes.close()
94119

95120
# Write image descriptors (if any)
@@ -98,12 +123,27 @@ def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarr
98123
img_bytes = io.BytesIO()
99124
np.save(img_bytes, y.astype(np.float32))
100125
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
101-
self.ft.addfile(get_tarinfo(f"{self.total_count}.desc_{i}.npy", img_bytes), img_bytes)
126+
self._ft.addfile(get_tarinfo(f"{self.total_count}.desc_{i}.npy", img_bytes), img_bytes)
102127
img_bytes.close()
103128

104129
self.sample_count += 1
105130
self.total_count += 1
106131

132+
def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray] = None, comment_str: str = ""):
133+
"""
134+
Add a sample to the current tar file.
135+
136+
Arguments:
137+
X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz).
138+
xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element].
139+
Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny).
140+
comment_str: Comment line (second line) to add to the xyz file.
141+
"""
142+
if self.async_write:
143+
self._q.put((X, xyzs, Y, comment_str), block=True, timeout=60)
144+
else:
145+
self._add_sample(X, xyzs, Y, comment_str)
146+
107147

108148
def get_tarinfo(fname: str, file_bytes: io.BytesIO):
109149
info = tarfile.TarInfo(fname)
@@ -128,13 +168,12 @@ class TarSampleList(TypedDict, total=False):
128168

129169
class TarDataGenerator:
130170
"""
131-
Iterable that loads data from tar archives with data saved in npz format for generating samples
132-
with the GeneratorAFMTrainer in ppafm.
171+
Iterable that loads data from tar archives with data saved in npz format for generating samples with ``GeneratorAFMTrainer``
172+
in *ppafm*.
133173
134174
The npz files should contain the following entries:
135175
136-
- ``'data'``: An array containing the potential/density on a 3D grid. The potential is assumed to be in
137-
units of eV and density in units of e/Å^3.
176+
- ``'data'``: An array containing the potential/density on a 3D grid.
138177
- ``'origin'``: Lattice origin in 3D space as an array of shape ``(3,)``.
139178
- ``'lattice'``: Lattice vectors as an array of shape ``(3, 3)``, where the rows are the vectors.
140179
- ``'xyz'``: Atom xyz coordinates as an array of shape ``(n_atoms, 3)``.
@@ -148,8 +187,9 @@ class TarDataGenerator:
148187
- ``'rho_sample'``: Sample electron density if the sample dict contained ``rho``, or ``None`` otherwise.
149188
- ``'rot'``: Rotation matrix.
150189
151-
Note: it is recommended to use ``multiprocessing.set_start_method('spawn')`` when using the :class:`TarDataGenerator`.
152-
Otherwise a lot of warnings about leaked memory objects may be thrown on exit.
190+
Note:
191+
It is recommended to use ``multiprocessing.set_start_method('spawn')`` when using the :class:`TarDataGenerator`.
192+
Otherwise a lot of warnings about leaked memory objects may be thrown on exit.
153193
154194
Arguments:
155195
samples: List of sample dicts as :class:`TarSampleList`. File paths should be relative to ``base_path``.

papers/asd-afm/generate_data.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,7 @@ def on_sample_start(self):
116116
if i % 100 == 0:
117117
elapsed = time.perf_counter() - start_gen
118118
eta = elapsed / (i + 1) * (len(generator) - i)
119-
print(
120-
f"{mode} sample {i}/{len(generator)}, writing to `{tar_writer.ft.name}`, "
121-
f"Elapsed: {elapsed:.2f}s, ETA: {eta:.2f}s"
122-
)
119+
print(f"{mode} sample {i}/{len(generator)}, Elapsed: {elapsed:.2f}s, ETA: {eta:.2f}s")
123120

124121
print(f"Done with {mode} - Elapsed time: {time.perf_counter() - start_gen:.1f}")
125122

papers/ed-afm/generate_data.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,7 @@ def handle_distance(self):
107107
if i % 100 == 0:
108108
elapsed = time.perf_counter() - start_gen
109109
eta = elapsed / (i + 1) * (len(generator) - i)
110-
print(
111-
f"{mode} sample {i}/{len(generator)}, writing to `{tar_writer.ft.name}`, "
112-
f"Elapsed: {elapsed:.2f}s, ETA: {eta:.2f}s"
113-
)
110+
print(f"{mode} sample {i}/{len(generator)}, Elapsed: {elapsed:.2f}s, ETA: {eta:.2f}s")
114111

115112
print(f"Done with {mode} - Elapsed time: {time.perf_counter() - start_gen:.1f}")
116113

0 commit comments

Comments
 (0)