Skip to content

Commit fd88880

Browse files
committed
Update potential/density buffers without reallocation.
1 parent a64a559 commit fd88880

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

mlspm/data_generation.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ def __init__(self, samples: list[TarSampleList], base_path: PathLike = "./", n_p
165165
self.samples = samples
166166
self.base_path = Path(base_path)
167167
self.n_proc = n_proc
168+
self.pot = None
169+
self.rho = None
168170

169171
def __len__(self) -> int:
170172
"""Total number of samples (including rotations)"""
@@ -292,24 +294,40 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
292294

293295
def _get_queue_sample(self):
294296

297+
if self._timings:
298+
t0 = time.perf_counter()
299+
295300
i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots = self.q.get(timeout=200)
296301

302+
if self._timings:
303+
t1 = time.perf_counter()
304+
297305
shm_pot = mp.shared_memory.SharedMemory(sample_id_pot)
298306
pot = np.ndarray(pot_shape, dtype=np.float32, buffer=shm_pot.buf)
299-
pot = HartreePotential(pot, lvec_pot)
300-
# This starts a copy to the OpenCL device. Better to start it here so that buffer preparation is instant during the simulation.
301-
pot.cl_array
307+
if self.pot is None:
308+
self.pot = HartreePotential(pot, lvec_pot)
309+
else:
310+
self.pot.update_array(pot, lvec_pot)
311+
312+
if self._timings:
313+
t2 = time.perf_counter()
302314

303315
if sample_id_rho is not None:
304316
shm_rho = mp.shared_memory.SharedMemory(sample_id_rho)
305317
rho = np.ndarray(rho_shape, dtype=np.float32, buffer=shm_rho.buf)
306-
rho = ElectronDensity(rho, lvec_rho)
307-
rho.cl_array
318+
if self.rho is None:
319+
self.rho = ElectronDensity(rho, lvec_rho)
320+
else:
321+
self.rho.update_array(pot, lvec_pot)
308322
else:
309323
shm_rho = None
310324
rho = None
311325

312-
return i_proc, xyzs, Zs, rots, pot, shm_pot, rho, shm_rho, sample_id_pot
326+
if self._timings:
327+
t3 = time.perf_counter()
328+
print(f"[Main, receive data, id {sample_id_pot}] Queue / Pot / Rho: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f}")
329+
330+
return i_proc, xyzs, Zs, rots, self.pot, shm_pot, self.rho, shm_rho, sample_id_pot
313331

314332
def _yield_samples(self):
315333

0 commit comments

Comments
 (0)