Commit 171bd34

Optimization of the sample loading.
Moved waiting for the release of shared memory until after the next sample is loaded, so that every process has something in the queue at all times.
1 parent b2c69a8 commit 171bd34
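
In other words, each worker now publishes sample i to the queue before it waits for the main process to release the shared memory of sample i - 1. Below is a minimal sketch of that pattern only, assuming the samples are a list of NumPy arrays and using hypothetical names (worker, q, event); the actual implementation is in the diff further down.

    import numpy as np
    from multiprocessing.shared_memory import SharedMemory

    def worker(samples, q, event):
        # Sketch only: publish sample i before releasing the shared memory of sample i - 1.
        shm_prev = None
        for arr in samples:
            # Copy the next sample into a fresh shared-memory block and queue its descriptor
            shm = SharedMemory(create=True, size=arr.nbytes)
            np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)[...] = arr
            q.put((shm.name, arr.shape, arr.dtype))

            if shm_prev is not None:
                # Only now wait for the main process to finish with the previous sample
                if not event.wait(timeout=60):
                    raise RuntimeError("Main process did not release the shared memory in time")
                event.clear()
                shm_prev.close()
                shm_prev.unlink()

            shm_prev = shm

        if shm_prev is not None:
            # Release the last sample's block once the main process signals
            if not event.wait(timeout=60):
                raise RuntimeError("Main process did not release the shared memory in time")
            event.clear()
            shm_prev.close()
            shm_prev.unlink()

Since each worker can then hold at most two samples in shared memory at once (the one just queued and the one awaiting release), the queue size in the diff is raised to 2 * self.n_proc.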

1 file changed (+44 −21)
mlspm/data_generation.py

@@ -139,7 +139,7 @@ class TarDataGenerator:
             - ``'lattice'``: Lattice vectors as an array of shape ``(3, 3)``, where the rows are the vectors.
             - ``'xyz'``: Atom xyz coordinates as an array of shape ``(n_atoms, 3)``.
             - ``'Z'``: Atom atomic numbers as an array of shape ``(n_atoms,)``.
-
+
     Yields dicts that contain the following:

         - ``'xyzs'``: Atom xyz coordinates.
@@ -155,6 +155,8 @@ class TarDataGenerator:
         samples: List of sample dicts as :class:`TarSampleList`. File paths should be relative to ``base_path``.
         base_path: Path to the directory with the tar files.
         n_proc: Number of parallel processes for loading data. The sample lists get divided evenly over the processes.
+            For memory usage, note that a maximum number of samples double the number of processes can be loaded into
+            memory at the same time.
     """

     _timings = False
@@ -169,7 +171,8 @@ def __len__(self) -> int:
         return sum([len(s["rots"]) for s in self.samples])

     def _launch_procs(self):
-        self.q = mp.Queue(maxsize=self.n_proc)
+        queue_size = 2 * self.n_proc
+        self.q = mp.Queue(queue_size)
         self.events = []
         samples_split = np.array_split(self.samples, self.n_proc)
         for i in range(self.n_proc):
@@ -201,6 +204,10 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
         proc_id = str(time.time_ns() + 1000000000 * i_proc)[-10:]
         print(f"Starting worker {i_proc}, id {proc_id}")

+        if self._timings:
+            start_time = time.perf_counter()
+            n_sample_tot = 0
+
         for sample_list in sample_lists:

             tar_path_hartree, name_list_hartree = sample_list["hartree"]
@@ -219,6 +226,8 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
                     f"Inconsistent number of samples between hartree and rho lists ({len(name_list_rho)} != {n_sample})"
                 )

+            shm_pot = shm_pot_prev = None
+            shm_rho = shm_rho_prev = None
             for i_sample in range(n_sample):

                 if self._timings:
@@ -253,41 +262,55 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
                     rho_shape = None
                     lvec_rho = None

-                if self._timings:
-                    t2 = time.perf_counter()
-
                 # Inform the main process of the data using the queue
                 self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots))

                 if self._timings:
-                    t3 = time.perf_counter()
+                    t2 = time.perf_counter()

-                # Wait until main process is done with the data
-                if not event.wait(timeout=60):
-                    raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.")
-                event.clear()
+                if i_sample > 0:

-                if self._timings:
-                    t4 = time.perf_counter()
+                    # Wait until main process is done with the previous data
+                    if not event.wait(timeout=60):
+                        raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.")
+                    event.clear()

-                # Done with shared memory
-                shm_pot.close()
-                shm_pot.unlink()
-                if use_rho:
-                    shm_rho.close()
-                    shm_rho.unlink()
+                    # Done with shared memory
+                    shm_pot_prev.close()
+                    shm_pot_prev.unlink()
+                    if use_rho:
+                        shm_rho_prev.close()
+                        shm_rho_prev.unlink()

                 if self._timings:
-                    t5 = time.perf_counter()
+                    t3 = time.perf_counter()
+                    n_sample_tot += 1
                     print(
-                        f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Queue / Wait / Unlink: "
-                        f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} / {t4 - t3:.5f} / {t5 - t4:.5f}"
+                        f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Wait-unlink: "
+                        f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} "
                     )

+                shm_pot_prev = shm_pot
+                shm_rho_prev = shm_rho
+
+            # Wait to unlink the last data
+            if not event.wait(timeout=60):
+                raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.")
+            event.clear()
+            shm_pot.close()
+            shm_pot.unlink()
+            if use_rho:
+                shm_rho.close()
+                shm_rho.unlink()
+
             tar_hartree.close()
             if use_rho:
                 tar_rho.close()

+        if self._timings:
+            dt = time.perf_counter() - start_time
+            print(f"[Worker {i_proc}]: Loaded {n_sample_tot} samples in {dt}s. Average load time: {dt / n_sample_tot}s.")
+
     def _yield_samples(self):

         for _ in range(len(self)):
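
For context, the main process's side of this handshake is not part of the diff. A rough sketch of what the receiving end of the queue/event protocol looks like is given below; the helper name consume_one and the queue payload are hypothetical, not the actual _yield_samples implementation, which carries sample ids and array shapes rather than a shared-memory name.

    import numpy as np
    from multiprocessing.shared_memory import SharedMemory

    def consume_one(q, events):
        # Sketch only: read one descriptor from the queue, copy the data out of the
        # worker's shared memory, and signal the worker so it can unlink the block.
        name, shape, dtype, i_proc = q.get(timeout=60)
        shm = SharedMemory(name=name)
        sample = np.ndarray(shape, dtype=dtype, buffer=shm.buf).copy()
        shm.close()
        events[i_proc].set()
        return sample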
