Skip to content

Commit a64a559

Browse files
committed
Initiate OpenCL buffer copy before yield. Code reorganization.
1 parent 171bd34 commit a64a559

File tree

1 file changed

+56
-50
lines changed

1 file changed

+56
-50
lines changed

mlspm/data_generation.py

+56-50
Original file line numberDiff line numberDiff line change
@@ -226,41 +226,37 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
226226
f"Inconsistent number of samples between hartree and rho lists ({len(name_list_rho)} != {n_sample})"
227227
)
228228

229-
shm_pot = shm_pot_prev = None
230-
shm_rho = shm_rho_prev = None
229+
shm_pot_prev = None
230+
shm_rho_prev = None
231+
231232
for i_sample in range(n_sample):
232233

233234
if self._timings:
234235
t0 = time.perf_counter()
235236

236237
# Load data from tar(s)
237238
rots = sample_list["rots"][i_sample]
238-
name_hartree = name_list_hartree[i_sample]
239-
pot, lvec_pot, xyzs, Zs = self._get_data(tar_hartree, name_hartree)
239+
pot, lvec_pot, xyzs, Zs = self._get_data(tar_hartree, name_list_hartree[i_sample])
240240
pot *= -1 # Unit conversion, eV -> V
241241
if use_rho:
242-
name_rho = name_list_rho[i_sample]
243-
rho, lvec_rho, _, _ = self._get_data(tar_rho, name_rho)
242+
rho, lvec_rho, _, _ = self._get_data(tar_rho, name_list_rho[i_sample])
243+
rho_shape = rho.shape
244+
else:
245+
lvec_rho = None
246+
rho_shape = None
244247

245248
if self._timings:
246249
t1 = time.perf_counter()
247250

248251
# Put the data to shared memory
249252
sample_id_pot = f"{i_proc}_{proc_id}_{i_sample}_pot"
250-
shm_pot = mp.shared_memory.SharedMemory(create=True, size=pot.nbytes, name=sample_id_pot)
251-
b = np.ndarray(pot.shape, dtype=np.float32, buffer=shm_pot.buf)
252-
b[:] = pot[:]
253-
253+
shm_pot = _put_to_shared_memory(pot, sample_id_pot)
254254
if use_rho:
255255
sample_id_rho = f"{i_proc}_{proc_id}_{i_sample}_rho"
256-
shm_rho = mp.shared_memory.SharedMemory(create=True, size=rho.nbytes, name=sample_id_rho)
257-
b = np.ndarray(rho.shape, dtype=np.float32, buffer=shm_rho.buf)
258-
b[:] = rho[:]
259-
rho_shape = rho.shape
256+
shm_rho = _put_to_shared_memory(rho, sample_id_rho)
260257
else:
261258
sample_id_rho = None
262-
rho_shape = None
263-
lvec_rho = None
259+
shm_rho = None
264260

265261
# Inform the main process of the data using the queue
266262
self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots))
@@ -269,18 +265,8 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
269265
t2 = time.perf_counter()
270266

271267
if i_sample > 0:
272-
273268
# Wait until main process is done with the previous data
274-
if not event.wait(timeout=60):
275-
raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.")
276-
event.clear()
277-
278-
# Done with shared memory
279-
shm_pot_prev.close()
280-
shm_pot_prev.unlink()
281-
if use_rho:
282-
shm_rho_prev.close()
283-
shm_rho_prev.unlink()
269+
_wait_and_unlink(i_proc, event, shm_pot_prev, shm_rho_prev)
284270

285271
if self._timings:
286272
t3 = time.perf_counter()
@@ -294,14 +280,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
294280
shm_rho_prev = shm_rho
295281

296282
# Wait to unlink the last data
297-
if not event.wait(timeout=60):
298-
raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.")
299-
event.clear()
300-
shm_pot.close()
301-
shm_pot.unlink()
302-
if use_rho:
303-
shm_rho.close()
304-
shm_rho.unlink()
283+
_wait_and_unlink(i_proc, event, shm_pot, shm_rho)
305284

306285
tar_hartree.close()
307286
if use_rho:
@@ -311,26 +290,35 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
311290
dt = time.perf_counter() - start_time
312291
print(f"[Worker {i_proc}]: Loaded {n_sample_tot} samples in {dt}s. Average load time: {dt / n_sample_tot}s.")
313292

293+
def _get_queue_sample(self):
294+
295+
i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots = self.q.get(timeout=200)
296+
297+
shm_pot = mp.shared_memory.SharedMemory(sample_id_pot)
298+
pot = np.ndarray(pot_shape, dtype=np.float32, buffer=shm_pot.buf)
299+
pot = HartreePotential(pot, lvec_pot)
300+
# This starts a copy to the OpenCL device. Better to start it here so that buffer preparation is instant during the simulation.
301+
pot.cl_array
302+
303+
if sample_id_rho is not None:
304+
shm_rho = mp.shared_memory.SharedMemory(sample_id_rho)
305+
rho = np.ndarray(rho_shape, dtype=np.float32, buffer=shm_rho.buf)
306+
rho = ElectronDensity(rho, lvec_rho)
307+
rho.cl_array
308+
else:
309+
shm_rho = None
310+
rho = None
311+
312+
return i_proc, xyzs, Zs, rots, pot, shm_pot, rho, shm_rho, sample_id_pot
313+
314314
def _yield_samples(self):
315315

316316
for _ in range(len(self)):
317317

318318
if self._timings:
319319
t0 = time.perf_counter()
320320

321-
# Get data info from the queue
322-
i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots = self.q.get(timeout=200)
323-
324-
# Get data from the shared memory
325-
shm_pot = mp.shared_memory.SharedMemory(sample_id_pot)
326-
pot = np.ndarray(pot_shape, dtype=np.float32, buffer=shm_pot.buf)
327-
pot = HartreePotential(pot, lvec_pot)
328-
if sample_id_rho is not None:
329-
shm_rho = mp.shared_memory.SharedMemory(sample_id_rho)
330-
rho = np.ndarray(rho_shape, dtype=np.float32, buffer=shm_rho.buf)
331-
rho = ElectronDensity(rho, lvec_rho)
332-
else:
333-
rho = None
321+
i_proc, xyzs, Zs, rots, pot, shm_pot, rho, shm_rho, sample_id = self._get_queue_sample()
334322

335323
if self._timings:
336324
t1 = time.perf_counter()
@@ -344,10 +332,28 @@ def _yield_samples(self):
344332

345333
# Close shared memory and inform producer that the shared memory can be unlinked
346334
shm_pot.close()
347-
if sample_id_rho is not None:
335+
if shm_rho is not None:
348336
shm_rho.close()
349337
self.events[i_proc].set()
350338

351339
if self._timings:
352340
t3 = time.perf_counter()
353-
print(f"[Main, id {sample_id_pot}] Receive data / Yield / Event: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f}")
341+
print(f"[Main, id {sample_id}] Receive data / Yield / Event: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f}")
342+
343+
344+
def _put_to_shared_memory(array, name):
345+
shm = mp.shared_memory.SharedMemory(create=True, size=array.nbytes, name=name)
346+
b = np.ndarray(array.shape, dtype=np.float32, buffer=shm.buf)
347+
b[:] = array[:]
348+
return shm
349+
350+
351+
def _wait_and_unlink(i_proc, event, shm_pot, shm_rho):
352+
if not event.wait(timeout=60):
353+
raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.")
354+
event.clear()
355+
shm_pot.close()
356+
shm_pot.unlink()
357+
if shm_rho:
358+
shm_rho.close()
359+
shm_rho.unlink()

0 commit comments

Comments
 (0)