@@ -226,41 +226,37 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
226
226
f"Inconsistent number of samples between hartree and rho lists ({ len (name_list_rho )} != { n_sample } )"
227
227
)
228
228
229
- shm_pot = shm_pot_prev = None
230
- shm_rho = shm_rho_prev = None
229
+ shm_pot_prev = None
230
+ shm_rho_prev = None
231
+
231
232
for i_sample in range (n_sample ):
232
233
233
234
if self ._timings :
234
235
t0 = time .perf_counter ()
235
236
236
237
# Load data from tar(s)
237
238
rots = sample_list ["rots" ][i_sample ]
238
- name_hartree = name_list_hartree [i_sample ]
239
- pot , lvec_pot , xyzs , Zs = self ._get_data (tar_hartree , name_hartree )
239
+ pot , lvec_pot , xyzs , Zs = self ._get_data (tar_hartree , name_list_hartree [i_sample ])
240
240
pot *= - 1 # Unit conversion, eV -> V
241
241
if use_rho :
242
- name_rho = name_list_rho [i_sample ]
243
- rho , lvec_rho , _ , _ = self ._get_data (tar_rho , name_rho )
242
+ rho , lvec_rho , _ , _ = self ._get_data (tar_rho , name_list_rho [i_sample ])
243
+ rho_shape = rho .shape
244
+ else :
245
+ lvec_rho = None
246
+ rho_shape = None
244
247
245
248
if self ._timings :
246
249
t1 = time .perf_counter ()
247
250
248
251
# Put the data to shared memory
249
252
sample_id_pot = f"{ i_proc } _{ proc_id } _{ i_sample } _pot"
250
- shm_pot = mp .shared_memory .SharedMemory (create = True , size = pot .nbytes , name = sample_id_pot )
251
- b = np .ndarray (pot .shape , dtype = np .float32 , buffer = shm_pot .buf )
252
- b [:] = pot [:]
253
-
253
+ shm_pot = _put_to_shared_memory (pot , sample_id_pot )
254
254
if use_rho :
255
255
sample_id_rho = f"{ i_proc } _{ proc_id } _{ i_sample } _rho"
256
- shm_rho = mp .shared_memory .SharedMemory (create = True , size = rho .nbytes , name = sample_id_rho )
257
- b = np .ndarray (rho .shape , dtype = np .float32 , buffer = shm_rho .buf )
258
- b [:] = rho [:]
259
- rho_shape = rho .shape
256
+ shm_rho = _put_to_shared_memory (rho , sample_id_rho )
260
257
else :
261
258
sample_id_rho = None
262
- rho_shape = None
263
- lvec_rho = None
259
+ shm_rho = None
264
260
265
261
# Inform the main process of the data using the queue
266
262
self .q .put ((i_proc , sample_id_pot , sample_id_rho , pot .shape , rho_shape , lvec_pot , lvec_rho , xyzs , Zs , rots ))
@@ -269,18 +265,8 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
269
265
t2 = time .perf_counter ()
270
266
271
267
if i_sample > 0 :
272
-
273
268
# Wait until main process is done with the previous data
274
- if not event .wait (timeout = 60 ):
275
- raise RuntimeError (f"[Worker { i_proc } ]: Did not receive signal from main process in 60 seconds." )
276
- event .clear ()
277
-
278
- # Done with shared memory
279
- shm_pot_prev .close ()
280
- shm_pot_prev .unlink ()
281
- if use_rho :
282
- shm_rho_prev .close ()
283
- shm_rho_prev .unlink ()
269
+ _wait_and_unlink (i_proc , event , shm_pot_prev , shm_rho_prev )
284
270
285
271
if self ._timings :
286
272
t3 = time .perf_counter ()
@@ -294,14 +280,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
294
280
shm_rho_prev = shm_rho
295
281
296
282
# Wait to unlink the last data
297
- if not event .wait (timeout = 60 ):
298
- raise RuntimeError (f"[Worker { i_proc } ]: Did not receive signal from main process in 60 seconds." )
299
- event .clear ()
300
- shm_pot .close ()
301
- shm_pot .unlink ()
302
- if use_rho :
303
- shm_rho .close ()
304
- shm_rho .unlink ()
283
+ _wait_and_unlink (i_proc , event , shm_pot , shm_rho )
305
284
306
285
tar_hartree .close ()
307
286
if use_rho :
@@ -311,26 +290,35 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
311
290
dt = time .perf_counter () - start_time
312
291
print (f"[Worker { i_proc } ]: Loaded { n_sample_tot } samples in { dt } s. Average load time: { dt / n_sample_tot } s." )
313
292
293
+ def _get_queue_sample (self ):
294
+
295
+ i_proc , sample_id_pot , sample_id_rho , pot_shape , rho_shape , lvec_pot , lvec_rho , xyzs , Zs , rots = self .q .get (timeout = 200 )
296
+
297
+ shm_pot = mp .shared_memory .SharedMemory (sample_id_pot )
298
+ pot = np .ndarray (pot_shape , dtype = np .float32 , buffer = shm_pot .buf )
299
+ pot = HartreePotential (pot , lvec_pot )
300
+ # This starts a copy to the OpenCL device. Better to start it here so that buffer preparation is instant during the simulation.
301
+ pot .cl_array
302
+
303
+ if sample_id_rho is not None :
304
+ shm_rho = mp .shared_memory .SharedMemory (sample_id_rho )
305
+ rho = np .ndarray (rho_shape , dtype = np .float32 , buffer = shm_rho .buf )
306
+ rho = ElectronDensity (rho , lvec_rho )
307
+ rho .cl_array
308
+ else :
309
+ shm_rho = None
310
+ rho = None
311
+
312
+ return i_proc , xyzs , Zs , rots , pot , shm_pot , rho , shm_rho , sample_id_pot
313
+
314
314
def _yield_samples (self ):
315
315
316
316
for _ in range (len (self )):
317
317
318
318
if self ._timings :
319
319
t0 = time .perf_counter ()
320
320
321
- # Get data info from the queue
322
- i_proc , sample_id_pot , sample_id_rho , pot_shape , rho_shape , lvec_pot , lvec_rho , xyzs , Zs , rots = self .q .get (timeout = 200 )
323
-
324
- # Get data from the shared memory
325
- shm_pot = mp .shared_memory .SharedMemory (sample_id_pot )
326
- pot = np .ndarray (pot_shape , dtype = np .float32 , buffer = shm_pot .buf )
327
- pot = HartreePotential (pot , lvec_pot )
328
- if sample_id_rho is not None :
329
- shm_rho = mp .shared_memory .SharedMemory (sample_id_rho )
330
- rho = np .ndarray (rho_shape , dtype = np .float32 , buffer = shm_rho .buf )
331
- rho = ElectronDensity (rho , lvec_rho )
332
- else :
333
- rho = None
321
+ i_proc , xyzs , Zs , rots , pot , shm_pot , rho , shm_rho , sample_id = self ._get_queue_sample ()
334
322
335
323
if self ._timings :
336
324
t1 = time .perf_counter ()
@@ -344,10 +332,28 @@ def _yield_samples(self):
344
332
345
333
# Close shared memory and inform producer that the shared memory can be unlinked
346
334
shm_pot .close ()
347
- if sample_id_rho is not None :
335
+ if shm_rho is not None :
348
336
shm_rho .close ()
349
337
self .events [i_proc ].set ()
350
338
351
339
if self ._timings :
352
340
t3 = time .perf_counter ()
353
- print (f"[Main, id { sample_id_pot } ] Receive data / Yield / Event: " f"{ t1 - t0 :.5f} / { t2 - t1 :.5f} / { t3 - t2 :.5f} " )
341
+ print (f"[Main, id { sample_id } ] Receive data / Yield / Event: " f"{ t1 - t0 :.5f} / { t2 - t1 :.5f} / { t3 - t2 :.5f} " )
342
+
343
+
344
+ def _put_to_shared_memory (array , name ):
345
+ shm = mp .shared_memory .SharedMemory (create = True , size = array .nbytes , name = name )
346
+ b = np .ndarray (array .shape , dtype = np .float32 , buffer = shm .buf )
347
+ b [:] = array [:]
348
+ return shm
349
+
350
+
351
+ def _wait_and_unlink (i_proc , event , shm_pot , shm_rho ):
352
+ if not event .wait (timeout = 60 ):
353
+ raise RuntimeError (f"[Worker { i_proc } ]: Did not receive signal from main process in 60 seconds." )
354
+ event .clear ()
355
+ shm_pot .close ()
356
+ shm_pot .unlink ()
357
+ if shm_rho :
358
+ shm_rho .close ()
359
+ shm_rho .unlink ()
0 commit comments