1
1
import io
2
2
import multiprocessing as mp
3
3
import os
4
+ import queue
4
5
import tarfile
5
6
import time
6
7
from multiprocessing .shared_memory import SharedMemory
7
8
from os import PathLike
8
9
from pathlib import Path
9
10
from typing import Optional , TypedDict
11
+ import warnings
10
12
11
13
import numpy as np
12
14
from PIL import Image
@@ -26,58 +28,81 @@ class TarWriter:
26
28
base_path: Path to directory where tar files are saved.
27
29
base_name: Base name for output tar files. The number of the tar file is appended to the name.
28
30
max_count: Maximum number of samples per tar file.
29
- png_compress_level: Compression level 1-9 for saved png images. Larger value for smaller file size but slower
30
- write speed.
31
+ async_write: Write tar files asynchronously in a parallel process.
31
32
"""
32
33
33
- def __init__ (self , base_path : PathLike = "./" , base_name : str = "" , max_count : int = 100 , png_compress_level = 4 ):
34
+ def __init__ (self , base_path : PathLike = "./" , base_name : str = "" , max_count : int = 100 , async_write = True ):
34
35
self .base_path = Path (base_path )
35
36
self .base_name = base_name
36
37
self .max_count = max_count
37
- self .png_compress_level = png_compress_level
38
+ self .async_write = async_write
38
39
39
40
def __enter__ (self ):
40
41
self .sample_count = 0
41
42
self .total_count = 0
42
43
self .tar_count = 0
43
- self .ft = self ._get_tar_file ()
44
+ if self .async_write :
45
+ self ._launch_write_process ()
46
+ else :
47
+ self ._ft = self ._get_tar_file ()
44
48
return self
45
49
46
50
def __exit__ (self , exc_type , exc_value , exc_traceback ):
47
- self .ft .close ()
51
+ if self .async_write :
52
+ self ._event_done .set ()
53
+ if not self ._event_tar_close .wait (60 ):
54
+ warnings .warn ("Write process did not respond within timeout period. Last tar file may not have been closed properly." )
55
+ else :
56
+ self ._ft .close ()
57
+
58
def _launch_write_process(self):
    """
    Create the inter-process primitives and start the process that runs ``_write_async``.

    Sets the following attributes:
        _q: Queue of size 1 used to pass samples to the write process.
        _event_done: Event set by the parent when no more samples will be added.
        _event_tar_close: Event set by the write process once the tar file has been closed.
        _write_process: Handle of the started process.
    """
    self._q = mp.Queue(1)
    self._event_done = mp.Event()
    self._event_tar_close = mp.Event()
    process = mp.Process(target=self._write_async)
    process.start()
    # Keep the handle so the writer can later be joined, terminated or checked for liveness. It is
    # assigned only after start() so that the Process object is not part of the state handed to the child.
    self._write_process = process
64
+
65
+ def _write_async (self ):
66
+ self ._ft = self ._get_tar_file ()
67
+ try :
68
+ while True :
69
+ try :
70
+ sample = self ._q .get (block = False )
71
+ self ._add_sample (* sample )
72
+ continue
73
+ except queue .Empty :
74
+ pass
75
+ if self ._event_done .is_set () and self ._q .empty ():
76
+ self ._ft .close ()
77
+ self ._event_tar_close .set ()
78
+ return
79
+ except :
80
+ self ._ft .close ()
81
+ self ._event_tar_close .set ()
48
82
49
83
def _get_tar_file (self ):
50
84
file_path = self .base_path / f"{ self .base_name } _{ self .tar_count } .tar"
51
85
if os .path .exists (file_path ):
52
86
raise RuntimeError (f"Tar file already exists at `{ file_path } `" )
53
87
return tarfile .open (file_path , "w" , format = tarfile .GNU_FORMAT )
54
88
55
- def add_sample (self , X : list [np .ndarray ], xyzs : np .ndarray , Y : Optional [np .ndarray ] = None , comment_str : str = "" ):
56
- """
57
- Add a sample to the current tar file.
58
-
59
- Arguments:
60
- X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz).
61
- xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element].
62
- Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny).
63
- comment_str: Comment line (second line) to add to the xyz file.
64
- """
89
+ def _add_sample (self , X , xyzs , Y , comment_str ):
65
90
66
91
if self .sample_count >= self .max_count :
67
92
self .tar_count += 1
68
93
self .sample_count = 0
69
- self .ft .close ()
70
- self .ft = self ._get_tar_file ()
94
+ self ._ft .close ()
95
+ self ._ft = self ._get_tar_file ()
71
96
72
97
# Write AFM images
73
98
for i , x in enumerate (X ):
74
99
for j in range (x .shape [- 1 ]):
75
100
xj = x [:, :, j ]
76
101
xj = ((xj - xj .min ()) / np .ptp (xj ) * (2 ** 8 - 1 )).astype (np .uint8 ) # Convert range to 0-255 integers
77
102
img_bytes = io .BytesIO ()
78
- Image .fromarray (xj .T [::- 1 ], mode = "L" ).save (img_bytes , "png" , compress_level = self . png_compress_level )
103
+ Image .fromarray (xj .T [::- 1 ], mode = "L" ).save (img_bytes , "png" )
79
104
img_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
80
- self .ft .addfile (get_tarinfo (f"{ self .total_count } .{ j :02d} .{ i } .png" , img_bytes ), img_bytes )
105
+ self ._ft .addfile (get_tarinfo (f"{ self .total_count } .{ j :02d} .{ i } .png" , img_bytes ), img_bytes )
81
106
img_bytes .close ()
82
107
83
108
# Write xyz file
@@ -89,7 +114,7 @@ def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarr
89
114
xyz_bytes .write (bytearray (f"{ xyz [i ]:10.8f} \t " , "utf-8" ))
90
115
xyz_bytes .write (bytearray ("\n " , "utf-8" ))
91
116
xyz_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
92
- self .ft .addfile (get_tarinfo (f"{ self .total_count } .xyz" , xyz_bytes ), xyz_bytes )
117
+ self ._ft .addfile (get_tarinfo (f"{ self .total_count } .xyz" , xyz_bytes ), xyz_bytes )
93
118
xyz_bytes .close ()
94
119
95
120
# Write image descriptors (if any)
@@ -98,12 +123,27 @@ def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarr
98
123
img_bytes = io .BytesIO ()
99
124
np .save (img_bytes , y .astype (np .float32 ))
100
125
img_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
101
- self .ft .addfile (get_tarinfo (f"{ self .total_count } .desc_{ i } .npy" , img_bytes ), img_bytes )
126
+ self ._ft .addfile (get_tarinfo (f"{ self .total_count } .desc_{ i } .npy" , img_bytes ), img_bytes )
102
127
img_bytes .close ()
103
128
104
129
self .sample_count += 1
105
130
self .total_count += 1
106
131
132
def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray] = None, comment_str: str = ""):
    """
    Add a sample to the current tar file.

    In asynchronous mode the sample is handed to the write process through the queue, blocking for at
    most 60 seconds while the queue is full. Otherwise the sample is written immediately.

    Arguments:
        X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz).
        xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element].
        Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny).
        comment_str: Comment line (second line) to add to the xyz file.
    """
    sample = (X, xyzs, Y, comment_str)
    if not self.async_write:
        self._add_sample(*sample)
        return
    self._q.put(sample, block=True, timeout=60)
146
+
107
147
108
148
def get_tarinfo (fname : str , file_bytes : io .BytesIO ):
109
149
info = tarfile .TarInfo (fname )
@@ -128,13 +168,12 @@ class TarSampleList(TypedDict, total=False):
128
168
129
169
class TarDataGenerator :
130
170
"""
131
- Iterable that loads data from tar archives with data saved in npz format for generating samples
132
- with the GeneratorAFMTrainer in ppafm.
171
+ Iterable that loads data from tar archives with data saved in npz format for generating samples with ``GeneratorAFMTrainer``
172
+ in * ppafm* .
133
173
134
174
The npz files should contain the following entries:
135
175
136
- - ``'data'``: An array containing the potential/density on a 3D grid. The potential is assumed to be in
137
- units of eV and density in units of e/Å^3.
176
+ - ``'data'``: An array containing the potential/density on a 3D grid.
138
177
- ``'origin'``: Lattice origin in 3D space as an array of shape ``(3,)``.
139
178
- ``'lattice'``: Lattice vectors as an array of shape ``(3, 3)``, where the rows are the vectors.
140
179
- ``'xyz'``: Atom xyz coordinates as an array of shape ``(n_atoms, 3)``.
@@ -148,8 +187,9 @@ class TarDataGenerator:
148
187
- ``'rho_sample'``: Sample electron density if the sample dict contained ``rho``, or ``None`` otherwise.
149
188
- ``'rot'``: Rotation matrix.
150
189
151
- Note: it is recommended to use ``multiprocessing.set_start_method('spawn')`` when using the :class:`TarDataGenerator`.
152
- Otherwise a lot of warnings about leaked memory objects may be thrown on exit.
190
+ Note:
191
+ It is recommended to use ``multiprocessing.set_start_method('spawn')`` when using the :class:`TarDataGenerator`.
192
+ Otherwise a lot of warnings about leaked memory objects may be thrown on exit.
153
193
154
194
Arguments:
155
195
samples: List of sample dicts as :class:`TarSampleList`. File paths should be relative to ``base_path``.
0 commit comments