1
-
2
1
import io
2
+ import multiprocessing as mp
3
+ import multiprocessing .shared_memory
3
4
import os
4
5
import tarfile
5
6
import time
6
7
from os import PathLike
7
8
from pathlib import Path
8
- from typing import List , Optional
9
+ from typing import Optional , TypedDict , NotRequired
9
10
10
11
import numpy as np
11
12
from PIL import Image
12
13
13
-
14
14
class TarWriter :
15
- '''
15
+ """
16
16
Write samples of AFM images, molecules and descriptors to tar files. Use as a context manager and add samples with
17
17
:meth:`add_sample`.
18
18
@@ -25,10 +25,10 @@ class TarWriter:
25
25
base_name: Base name for output tar files. The number of the tar file is appended to the name.
26
26
max_count: Maximum number of samples per tar file.
27
27
png_compress_level: Compression level 1-9 for saved png images. Larger value for smaller file size but slower
28
- write speed.
29
- '''
28
+ write speed.
29
+ """
30
30
31
- def __init__ (self , base_path : PathLike = './' , base_name : str = '' , max_count : int = 100 , png_compress_level = 4 ):
31
+ def __init__ (self , base_path : PathLike = "./" , base_name : str = "" , max_count : int = 100 , png_compress_level = 4 ):
32
32
self .base_path = Path (base_path )
33
33
self .base_name = base_name
34
34
self .max_count = max_count
@@ -45,12 +45,12 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
45
45
self .ft .close ()
46
46
47
47
def _get_tar_file (self ):
48
- file_path = self .base_path / f' { self .base_name } _{ self .tar_count } .tar'
48
+ file_path = self .base_path / f" { self .base_name } _{ self .tar_count } .tar"
49
49
if os .path .exists (file_path ):
50
- raise RuntimeError (f' Tar file already exists at `{ file_path } `' )
51
- return tarfile .open (file_path , 'w' , format = tarfile .GNU_FORMAT )
50
+ raise RuntimeError (f" Tar file already exists at `{ file_path } `" )
51
+ return tarfile .open (file_path , "w" , format = tarfile .GNU_FORMAT )
52
52
53
- def add_sample (self , X : List [np .ndarray ], xyzs : np .ndarray , Y : Optional [np .ndarray ]= None , comment_str : str = '' ):
53
+ def add_sample (self , X : list [np .ndarray ], xyzs : np .ndarray , Y : Optional [np .ndarray ] = None , comment_str : str = "" ):
54
54
"""
55
55
Add a sample to the current tar file.
56
56
@@ -71,39 +71,225 @@ def add_sample(self, X: List[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarr
71
71
for i , x in enumerate (X ):
72
72
for j in range (x .shape [- 1 ]):
73
73
xj = x [:, :, j ]
74
- xj = ((xj - xj .min ()) / np .ptp (xj ) * (2 ** 8 - 1 )).astype (np .uint8 ) # Convert range to 0-255 integers
74
+ xj = ((xj - xj .min ()) / np .ptp (xj ) * (2 ** 8 - 1 )).astype (np .uint8 ) # Convert range to 0-255 integers
75
75
img_bytes = io .BytesIO ()
76
- Image .fromarray (xj .T [::- 1 ], mode = 'L' ).save (img_bytes , ' png' , compress_level = self .png_compress_level )
77
- img_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
78
- self .ft .addfile (get_tarinfo (f' { self .total_count } .{ j :02d} .{ i } .png' , img_bytes ), img_bytes )
76
+ Image .fromarray (xj .T [::- 1 ], mode = "L" ).save (img_bytes , " png" , compress_level = self .png_compress_level )
77
+ img_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
78
+ self .ft .addfile (get_tarinfo (f" { self .total_count } .{ j :02d} .{ i } .png" , img_bytes ), img_bytes )
79
79
img_bytes .close ()
80
-
80
+
81
81
# Write xyz file
82
82
xyz_bytes = io .BytesIO ()
83
- xyz_bytes .write (bytearray (f' { len (xyzs )} \n { comment_str } \n ' , ' utf-8' ))
83
+ xyz_bytes .write (bytearray (f" { len (xyzs )} \n { comment_str } \n " , " utf-8" ))
84
84
for xyz in xyzs :
85
- xyz_bytes .write (bytearray (f' { int (xyz [- 1 ])} \t ' , ' utf-8' ))
86
- for i in range (len (xyz )- 1 ):
87
- xyz_bytes .write (bytearray (f' { xyz [i ]:10.8f} \t ' , ' utf-8' ))
88
- xyz_bytes .write (bytearray (' \n ' , ' utf-8' ))
89
- xyz_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
90
- self .ft .addfile (get_tarinfo (f' { self .total_count } .xyz' , xyz_bytes ), xyz_bytes )
85
+ xyz_bytes .write (bytearray (f" { int (xyz [- 1 ])} \t " , " utf-8" ))
86
+ for i in range (len (xyz ) - 1 ):
87
+ xyz_bytes .write (bytearray (f" { xyz [i ]:10.8f} \t " , " utf-8" ))
88
+ xyz_bytes .write (bytearray (" \n " , " utf-8" ))
89
+ xyz_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
90
+ self .ft .addfile (get_tarinfo (f" { self .total_count } .xyz" , xyz_bytes ), xyz_bytes )
91
91
xyz_bytes .close ()
92
92
93
93
# Write image descriptors (if any)
94
94
if Y is not None :
95
95
for i , y in enumerate (Y ):
96
96
img_bytes = io .BytesIO ()
97
97
np .save (img_bytes , y .astype (np .float32 ))
98
- img_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
99
- self .ft .addfile (get_tarinfo (f' { self .total_count } .desc_{ i } .npy' , img_bytes ), img_bytes )
98
+ img_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
99
+ self .ft .addfile (get_tarinfo (f" { self .total_count } .desc_{ i } .npy" , img_bytes ), img_bytes )
100
100
img_bytes .close ()
101
101
102
102
self .sample_count += 1
103
103
self .total_count += 1
104
104
105
+
105
106
def get_tarinfo(fname: str, file_bytes: io.BytesIO):
    """
    Build the tar header entry for an in-memory file.

    Arguments:
        fname: Member name to use inside the tar archive.
        file_bytes: In-memory stream holding the member contents.

    Returns:
        A :class:`tarfile.TarInfo` whose size is the total number of bytes held by ``file_bytes`` and whose
        modification time is the current time.
    """
    # Measure the whole underlying buffer, independent of the current stream position.
    with file_bytes.getbuffer() as view:
        n_bytes = view.nbytes

    tar_info = tarfile.TarInfo(fname)
    tar_info.mtime = time.time()
    tar_info.size = n_bytes
    return tar_info
111
+
112
+ class TarSample (TypedDict ):
113
+ """
114
+ - ``'hartree'``: Path to the Hartree potential. First item in the tuple is the path
115
+ to the tar file relative to ``base_path``, and second entry is the tar file member name.
116
+ - ``'rho'``: (Optional) Path to the electron density. First item in the tuple is the path
117
+ to the tar file relative to ``base_path``, and second entry is the tar file member name.
118
+ - ``'rots'``: List of rotations to generate for the sample.
119
+ """
120
+ hartree : tuple [str , str ]
121
+ rho : NotRequired [tuple [str , str ]]
122
+ rots : list [np .ndarray ]
123
+
124
class TarDataGenerator:
    """
    Iterable that loads data from tar archives with data saved in npz format for generating samples
    with the GeneratorAFMTrainer in ppafm.

    The npz files should contain the following entries:

        - ``'data'``: An array containing the potential/density on a 3D grid.
        - ``'origin'``: Lattice origin in 3D space as an array of shape ``(3,)``.
        - ``'lattice'``: Lattice vectors as an array of shape ``(3, 3)``, where the rows are the vectors.
        - ``'xyz'``: Atom xyz coordinates as an array of shape ``(n_atoms, 3)``.
        - ``'Z'``: Atom atomic numbers as an array of shape ``(n_atoms,)``.

    Each worker process loads one sample at a time, copies the grid(s) into named shared memory, and announces the
    sample on a queue. The main process attaches to the shared memory, yields one dict per rotation, and then sets
    the worker's event so that the worker can release the shared memory and move on to its next sample.

    Arguments:
        samples: List of sample dicts as :class:`TarSample`. If ``rho`` is present in the dict, the full-density-based model
            is used in the simulation. Otherwise Lennard-Jones with Hartree electrostatics is used.
        base_path: Path to the directory with the tar files.
        n_proc: Number of parallel processes for loading data. The samples get divided evenly over the processes.
    """

    # Print timing information for each stage when set to True.
    _timings = False
    # Seconds a worker waits for the main process to finish with a sample before giving up.
    _worker_timeout = 60
    # Seconds the main process waits for the next sample to appear on the queue.
    _queue_timeout = 200

    def __init__(self, samples: list[TarSample], base_path: PathLike = "./", n_proc: int = 1):
        self.samples = samples
        self.base_path = Path(base_path)
        self.n_proc = n_proc

    def __len__(self):
        """Total number of samples (including rotations)"""
        return sum(len(s["rots"]) for s in self.samples)

    def _launch_procs(self):
        """Start the worker processes, each with its own slice of the samples and its own event."""
        self.q = mp.Queue(maxsize=self.n_proc)
        self.events = []
        samples_split = np.array_split(self.samples, self.n_proc)
        for i in range(self.n_proc):
            event = mp.Event()
            p = mp.Process(target=self._load_samples, args=(samples_split[i], i, event))
            p.start()
            self.events.append(event)

    def __iter__(self):
        self._launch_procs()
        self.iterator = iter(self._yield_samples())
        return self

    def __next__(self):
        return next(self.iterator)

    def _get_data(self, tar_path: PathLike, name: str):
        """
        Load one npz member from a tar file.

        Arguments:
            tar_path: Path to the tar file relative to ``base_path``.
            name: Name of the npz member inside the tar file.

        Returns:
            Tuple ``(array, lvec, xyzs, Zs)``, where ``lvec`` is the origin stacked on top of the lattice vectors.
        """
        tar_path = self.base_path / tar_path
        with tarfile.open(tar_path, "r") as f:
            # The npz entries are read lazily, so all of them have to be accessed while the tar file is still open.
            data = np.load(f.extractfile(name))
            array = data["data"]
            origin = data["origin"]
            lattice = data["lattice"]
            xyzs = data["xyz"]
            Zs = data["Z"]
        lvec = np.concatenate([origin[None, :], lattice], axis=0)
        return array, lvec, xyzs, Zs

    @staticmethod
    def _to_shared_memory(array: np.ndarray, name: str):
        """Create a named shared memory segment holding a float32 copy of ``array`` and return the segment."""
        # Size the segment by the float32 copy rather than by the input dtype, so that inputs narrower than
        # four bytes per element do not produce a buffer that is too small.
        n_bytes = array.size * np.dtype(np.float32).itemsize
        shm = mp.shared_memory.SharedMemory(create=True, size=n_bytes, name=name)
        try:
            view = np.ndarray(array.shape, dtype=np.float32, buffer=shm.buf)
            view[:] = array[:]
        except BaseException:
            # Do not leave a half-initialized segment behind.
            shm.unlink()
            shm.close()
            raise
        return shm

    def _load_samples(self, samples: list[TarSample], i_proc: int, event: mp.Event):
        """
        Worker loop: load each sample, publish it through shared memory, and wait for the consumer.

        Arguments:
            samples: The samples assigned to this worker.
            i_proc: Index of this worker. Used in the shared memory names and reported back through the queue.
            event: Event that the main process sets when it is done with the current sample.

        Raises:
            RuntimeError: If the main process does not set the event within ``_worker_timeout`` seconds.
        """

        proc_id = str(time.time_ns() + 1000000000 * i_proc)[-10:]
        print(f"Starting worker {i_proc}, id {proc_id}")

        for i, sample in enumerate(samples):

            if self._timings:
                t0 = time.perf_counter()

            # Load data from tar(s)
            rots = sample["rots"]
            hartree_tar_path, name = sample["hartree"]
            pot, lvec, xyzs, Zs = self._get_data(hartree_tar_path, name)
            pot *= -1  # Unit conversion, eV -> V
            rho = None
            if "rho" in sample:
                rho_tar_path, name = sample["rho"]
                rho, _, _, _ = self._get_data(rho_tar_path, name)

            if self._timings:
                t1 = time.perf_counter()

            sample_id_pot = f"{i_proc}_{proc_id}_{i}_pot"
            if rho is not None:
                sample_id_rho = f"{i_proc}_{proc_id}_{i}__rho"
                rho_shape = rho.shape
            else:
                sample_id_rho = None
                rho_shape = None

            segments = []
            try:
                # Put the data to shared memory
                segments.append(self._to_shared_memory(pot, sample_id_pot))
                if rho is not None:
                    segments.append(self._to_shared_memory(rho, sample_id_rho))

                if self._timings:
                    t2 = time.perf_counter()

                # Inform the main process of the data using the queue
                self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec, xyzs, Zs, rots))

                if self._timings:
                    t3 = time.perf_counter()

                # Wait until main process is done with the data
                if not event.wait(timeout=self._worker_timeout):
                    raise RuntimeError(
                        f"[Worker {i_proc}]: Did not receive signal from main process in {self._worker_timeout} seconds."
                    )
                event.clear()

                if self._timings:
                    t4 = time.perf_counter()

            finally:
                # Done with shared memory. Release it also on failure so that the segments are not leaked.
                for shm in segments:
                    shm.unlink()
                    shm.close()

            if self._timings:
                t5 = time.perf_counter()
                print(
                    f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Queue / Wait / Unlink: "
                    f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} / {t4 - t3:.5f} / {t5 - t4:.5f}"
                )

    def _yield_samples(self):
        """
        Generator that receives samples from the workers and yields one dict per rotation.

        Each yielded dict has the keys ``'xyzs'``, ``'Zs'``, ``'qs'``, ``'rho_sample'`` and ``'rot'``. The grids
        are backed by shared memory that is closed once all rotations of the sample have been yielded.
        """

        from ppafm.ocl.field import ElectronDensity, HartreePotential

        # The workers put one queue item per sample (not per rotation), so that is the number of items to receive.
        for _ in range(len(self.samples)):

            if self._timings:
                t0 = time.perf_counter()

            # Get data info from the queue
            i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec, xyzs, Zs, rots = self.q.get(
                timeout=self._queue_timeout
            )

            # Get data from the shared memory
            shm_pot = mp.shared_memory.SharedMemory(sample_id_pot)
            shm_rho = None
            try:
                pot = np.ndarray(pot_shape, dtype=np.float32, buffer=shm_pot.buf)
                pot = HartreePotential(pot, lvec)
                if sample_id_rho is not None:
                    shm_rho = mp.shared_memory.SharedMemory(sample_id_rho)
                    rho = np.ndarray(rho_shape, dtype=np.float32, buffer=shm_rho.buf)
                    rho = ElectronDensity(rho, lvec)
                else:
                    rho = None

                if self._timings:
                    t1 = time.perf_counter()

                for rot in rots:
                    sample_dict = {"xyzs": xyzs, "Zs": Zs, "qs": pot, "rho_sample": rho, "rot": rot}
                    yield sample_dict

                if self._timings:
                    t2 = time.perf_counter()

            finally:
                # Close shared memory and inform producer that the shared memory can be unlinked. This also runs
                # when the generator is closed or fails mid-sample, so the worker is not left waiting.
                shm_pot.close()
                if shm_rho is not None:
                    shm_rho.close()
                self.events[i_proc].set()

            if self._timings:
                t3 = time.perf_counter()
                print(f"[Main, id {sample_id_pot}] Receive data / Yield / Event: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f}")
0 commit comments