@@ -206,9 +206,9 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
206
206
proc_id = str (time .time_ns () + 1000000000 * i_proc )[- 10 :]
207
207
print (f"Starting worker { i_proc } , id { proc_id } " )
208
208
209
- if self . _timings :
210
- start_time = time . perf_counter ()
211
- n_sample_tot = 0
209
+ start_time = time . perf_counter ()
210
+ total_bytes = 0
211
+ n_sample_total = 0
212
212
213
213
for sample_list in sample_lists :
214
214
@@ -240,9 +240,11 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
240
240
rots = sample_list ["rots" ][i_sample ]
241
241
pot , lvec_pot , xyzs , Zs = self ._get_data (tar_hartree , name_list_hartree [i_sample ])
242
242
pot *= - 1 # Unit conversion, eV -> V
243
+ total_bytes += pot .nbytes
243
244
if use_rho :
244
245
rho , lvec_rho , _ , _ = self ._get_data (tar_rho , name_list_rho [i_sample ])
245
246
rho_shape = rho .shape
247
+ total_bytes += rho .nbytes
246
248
else :
247
249
lvec_rho = None
248
250
rho_shape = None
@@ -272,7 +274,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
272
274
273
275
if self ._timings :
274
276
t3 = time .perf_counter ()
275
- n_sample_tot += 1
277
+ n_sample_total += 1
276
278
print (
277
279
f"[Worker { i_proc } , id { sample_id_pot } ] Get data / Shm / Wait-unlink: "
278
280
f"{ t1 - t0 :.5f} / { t2 - t1 :.5f} / { t3 - t2 :.5f} "
@@ -290,7 +292,10 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
290
292
291
293
if self ._timings :
292
294
dt = time .perf_counter () - start_time
293
- print (f"[Worker { i_proc } ]: Loaded { n_sample_tot } samples in { dt } s. Average load time: { dt / n_sample_tot } s." )
295
+ print (
296
+ f"[Worker { i_proc } ]: Loaded { n_sample_total } samples in { dt } s, totaling { total_bytes / 2 ** 30 :.3f} GiB. "
297
+ f"Average load time: { dt / n_sample_total } s."
298
+ )
294
299
295
300
def _get_queue_sample (self ):
296
301
@@ -331,19 +336,24 @@ def _get_queue_sample(self):
331
336
332
337
def _yield_samples (self ):
333
338
334
- for _ in range (len (self )):
339
+ start_time = time .perf_counter ()
340
+ n_sample_yielded = 0
341
+
342
+ n_sample_total = sum ([len (sample_list ["rots" ]) for sample_list in self .samples ])
343
+
344
+ for _ in range (n_sample_total ):
335
345
336
346
if self ._timings :
337
347
t0 = time .perf_counter ()
338
348
339
349
i_proc , xyzs , Zs , rots , pot , shm_pot , rho , shm_rho , sample_id = self ._get_queue_sample ()
340
-
341
350
if self ._timings :
342
351
t1 = time .perf_counter ()
343
352
344
353
for rot in rots :
345
354
sample_dict = {"xyzs" : xyzs , "Zs" : Zs , "qs" : pot , "rho_sample" : rho , "rot" : rot }
346
355
yield sample_dict
356
+ n_sample_yielded += 1
347
357
348
358
if self ._timings :
349
359
t2 = time .perf_counter ()
@@ -358,6 +368,10 @@ def _yield_samples(self):
358
368
t3 = time .perf_counter ()
359
369
print (f"[Main, id { sample_id } ] Receive data / Yield / Event: " f"{ t1 - t0 :.5f} / { t2 - t1 :.5f} / { t3 - t2 :.5f} " )
360
370
371
+ if self ._timings :
372
+ dt = time .perf_counter () - start_time
373
+ print (f"[Main]: Yielded { n_sample_yielded } samples in { dt } s. Average yield time: { dt / n_sample_yielded } s." )
374
+
361
375
362
376
def _put_to_shared_memory (array , name ):
363
377
shm = mp .shared_memory .SharedMemory (create = True , size = array .nbytes , name = name )
0 commit comments