Skip to content

Commit 9af2380

Browse files
committed
Tar generator timings.
1 parent d3f4740 commit 9af2380

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

mlspm/data_generation.py

+21 -7
Original file line number | Diff line number | Diff line change
@@ -206,9 +206,9 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
206206
proc_id = str(time.time_ns() + 1000000000 * i_proc)[-10:]
207207
print(f"Starting worker {i_proc}, id {proc_id}")
208208

209-
if self._timings:
210-
start_time = time.perf_counter()
211-
n_sample_tot = 0
209+
start_time = time.perf_counter()
210+
total_bytes = 0
211+
n_sample_total = 0
212212

213213
for sample_list in sample_lists:
214214

@@ -240,9 +240,11 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
240240
rots = sample_list["rots"][i_sample]
241241
pot, lvec_pot, xyzs, Zs = self._get_data(tar_hartree, name_list_hartree[i_sample])
242242
pot *= -1 # Unit conversion, eV -> V
243+
total_bytes += pot.nbytes
243244
if use_rho:
244245
rho, lvec_rho, _, _ = self._get_data(tar_rho, name_list_rho[i_sample])
245246
rho_shape = rho.shape
247+
total_bytes += rho.nbytes
246248
else:
247249
lvec_rho = None
248250
rho_shape = None
@@ -272,7 +274,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
272274

273275
if self._timings:
274276
t3 = time.perf_counter()
275-
n_sample_tot += 1
277+
n_sample_total += 1
276278
print(
277279
f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Wait-unlink: "
278280
f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} "
@@ -290,7 +292,10 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
290292

291293
if self._timings:
292294
dt = time.perf_counter() - start_time
293-
print(f"[Worker {i_proc}]: Loaded {n_sample_tot} samples in {dt}s. Average load time: {dt / n_sample_tot}s.")
295+
print(
296+
f"[Worker {i_proc}]: Loaded {n_sample_total} samples in {dt}s, totaling {total_bytes / 2**30:.3f}GiB. "
297+
f"Average load time: {dt / n_sample_total}s."
298+
)
294299

295300
def _get_queue_sample(self):
296301

@@ -331,19 +336,24 @@ def _get_queue_sample(self):
331336

332337
def _yield_samples(self):
333338

334-
for _ in range(len(self)):
339+
start_time = time.perf_counter()
340+
n_sample_yielded = 0
341+
342+
n_sample_total = sum([len(sample_list["rots"]) for sample_list in self.samples])
343+
344+
for _ in range(n_sample_total):
335345

336346
if self._timings:
337347
t0 = time.perf_counter()
338348

339349
i_proc, xyzs, Zs, rots, pot, shm_pot, rho, shm_rho, sample_id = self._get_queue_sample()
340-
341350
if self._timings:
342351
t1 = time.perf_counter()
343352

344353
for rot in rots:
345354
sample_dict = {"xyzs": xyzs, "Zs": Zs, "qs": pot, "rho_sample": rho, "rot": rot}
346355
yield sample_dict
356+
n_sample_yielded += 1
347357

348358
if self._timings:
349359
t2 = time.perf_counter()
@@ -358,6 +368,10 @@ def _yield_samples(self):
358368
t3 = time.perf_counter()
359369
print(f"[Main, id {sample_id}] Receive data / Yield / Event: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f}")
360370

371+
if self._timings:
372+
dt = time.perf_counter() - start_time
373+
print(f"[Main]: Yielded {n_sample_yielded} samples in {dt}s. Average yield time: {dt / n_sample_yielded}s.")
374+
361375

362376
def _put_to_shared_memory(array, name):
363377
shm = mp.shared_memory.SharedMemory(create=True, size=array.nbytes, name=name)

0 commit comments

Comments
 (0)