@@ -139,7 +139,7 @@ class TarDataGenerator:
139
139
- ``'lattice'``: Lattice vectors as an array of shape ``(3, 3)``, where the rows are the vectors.
140
140
- ``'xyz'``: Atom xyz coordinates as an array of shape ``(n_atoms, 3)``.
141
141
- ``'Z'``: Atom atomic numbers as an array of shape ``(n_atoms,)``.
142
-
142
+
143
143
Yields dicts that contain the following:
144
144
145
145
- ``'xyzs'``: Atom xyz coordinates.
@@ -155,6 +155,8 @@ class TarDataGenerator:
155
155
samples: List of sample dicts as :class:`TarSampleList`. File paths should be relative to ``base_path``.
156
156
base_path: Path to the directory with the tar files.
157
157
n_proc: Number of parallel processes for loading data. The sample lists get divided evenly over the processes.
158
+ Regarding memory usage, note that up to twice as many samples as there are processes can be held in
159
+ memory at the same time.
158
160
"""
159
161
160
162
_timings = False
@@ -169,7 +171,8 @@ def __len__(self) -> int:
169
171
return sum ([len (s ["rots" ]) for s in self .samples ])
170
172
171
173
def _launch_procs (self ):
172
- self .q = mp .Queue (maxsize = self .n_proc )
174
+ queue_size = 2 * self .n_proc
175
+ self .q = mp .Queue (queue_size )
173
176
self .events = []
174
177
samples_split = np .array_split (self .samples , self .n_proc )
175
178
for i in range (self .n_proc ):
@@ -201,6 +204,10 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
201
204
proc_id = str (time .time_ns () + 1000000000 * i_proc )[- 10 :]
202
205
print (f"Starting worker { i_proc } , id { proc_id } " )
203
206
207
+ if self ._timings :
208
+ start_time = time .perf_counter ()
209
+ n_sample_tot = 0
210
+
204
211
for sample_list in sample_lists :
205
212
206
213
tar_path_hartree , name_list_hartree = sample_list ["hartree" ]
@@ -219,6 +226,8 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
219
226
f"Inconsistent number of samples between hartree and rho lists ({ len (name_list_rho )} != { n_sample } )"
220
227
)
221
228
229
+ shm_pot = shm_pot_prev = None
230
+ shm_rho = shm_rho_prev = None
222
231
for i_sample in range (n_sample ):
223
232
224
233
if self ._timings :
@@ -253,41 +262,55 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
253
262
rho_shape = None
254
263
lvec_rho = None
255
264
256
- if self ._timings :
257
- t2 = time .perf_counter ()
258
-
259
265
# Inform the main process of the data using the queue
260
266
self .q .put ((i_proc , sample_id_pot , sample_id_rho , pot .shape , rho_shape , lvec_pot , lvec_rho , xyzs , Zs , rots ))
261
267
262
268
if self ._timings :
263
- t3 = time .perf_counter ()
269
+ t2 = time .perf_counter ()
264
270
265
- # Wait until main process is done with the data
266
- if not event .wait (timeout = 60 ):
267
- raise RuntimeError (f"[Worker { i_proc } ]: Did not receive signal from main process in 60 seconds." )
268
- event .clear ()
271
+ if i_sample > 0 :
269
272
270
- if self ._timings :
271
- t4 = time .perf_counter ()
273
+ # Wait until main process is done with the previous data
274
+ if not event .wait (timeout = 60 ):
275
+ raise RuntimeError (f"[Worker { i_proc } ]: Did not receive signal from main process in 60 seconds." )
276
+ event .clear ()
272
277
273
- # Done with shared memory
274
- shm_pot .close ()
275
- shm_pot .unlink ()
276
- if use_rho :
277
- shm_rho .close ()
278
- shm_rho .unlink ()
278
+ # Done with shared memory
279
+ shm_pot_prev .close ()
280
+ shm_pot_prev .unlink ()
281
+ if use_rho :
282
+ shm_rho_prev .close ()
283
+ shm_rho_prev .unlink ()
279
284
280
285
if self ._timings :
281
- t5 = time .perf_counter ()
286
+ t3 = time .perf_counter ()
287
+ n_sample_tot += 1
282
288
print (
283
- f"[Worker { i_proc } , id { sample_id_pot } ] Get data / Shm / Queue / Wait / Unlink : "
284
- f"{ t1 - t0 :.5f} / { t2 - t1 :.5f} / { t3 - t2 :.5f} / { t4 - t3 :.5f } / { t5 - t4 :.5f } "
289
+ f"[Worker { i_proc } , id { sample_id_pot } ] Get data / Shm / Wait-unlink : "
290
+ f"{ t1 - t0 :.5f} / { t2 - t1 :.5f} / { t3 - t2 :.5f} "
285
291
)
286
292
293
+ shm_pot_prev = shm_pot
294
+ shm_rho_prev = shm_rho
295
+
296
+ # Wait to unlink the last data
297
+ if not event .wait (timeout = 60 ):
298
+ raise RuntimeError (f"[Worker { i_proc } ]: Did not receive signal from main process in 60 seconds." )
299
+ event .clear ()
300
+ shm_pot .close ()
301
+ shm_pot .unlink ()
302
+ if use_rho :
303
+ shm_rho .close ()
304
+ shm_rho .unlink ()
305
+
287
306
tar_hartree .close ()
288
307
if use_rho :
289
308
tar_rho .close ()
290
309
310
+ if self ._timings :
311
+ dt = time .perf_counter () - start_time
312
+ print (f"[Worker { i_proc } ]: Loaded { n_sample_tot } samples in { dt } s. Average load time: { dt / n_sample_tot } s." )
313
+
291
314
def _yield_samples (self ):
292
315
293
316
for _ in range (len (self )):
0 commit comments