From 4f979609dd975228b12969b6705df986064b6d2d Mon Sep 17 00:00:00 2001 From: NikoOinonen Date: Mon, 11 Mar 2024 20:21:57 +0200 Subject: [PATCH] Changed TarDataGenerator sample format. The samples are now divided by the tar file so that the access pattern is sequential. --- mlspm/data_generation.py | 199 ++++++++++++++++++++++----------------- 1 file changed, 110 insertions(+), 89 deletions(-) diff --git a/mlspm/data_generation.py b/mlspm/data_generation.py index d9135e9..c1a910b 100644 --- a/mlspm/data_generation.py +++ b/mlspm/data_generation.py @@ -10,6 +10,8 @@ import numpy as np from PIL import Image +from ppafm.ocl.field import ElectronDensity, HartreePotential + class TarWriter: """ @@ -109,18 +111,21 @@ def get_tarinfo(fname: str, file_bytes: io.BytesIO): info.mtime = time.time() return info -class TarSample(TypedDict, total=False): + +class TarSampleList(TypedDict, total=False): """ - - ``'hartree'``: Path to the Hartree potential. First item in the tuple is the path - to the tar file relative to ``base_path``, and second entry is the tar file member name. - - ``'rho'``: (Optional) Path to the electron density. First item in the tuple is the path - to the tar file relative to ``base_path``, and second entry is the tar file member name. - - ``'rots'``: List of rotations to generate for the sample. + - ``'hartree'``: Paths to the Hartree potentials. First item in the tuple is the path to the tar file, + and second entry is a list of tar file member names. + - ``'rho'``: (Optional) Paths to the electron densities. First item in the tuple is the path to the tar + file, and second entry is a list tar file member names. + - ``'rots'``: List of rotations for each sample. """ - hartree: tuple[str, str] - rho: tuple[str, str] + + hartree: tuple[PathLike, list[str]] + rho: tuple[PathLike, list[str]] rots: list[np.ndarray] + class TarDataGenerator: """ Iterable that loads data from tar archives with data saved in npz format for generating samples @@ -135,20 +140,19 @@ class TarDataGenerator: - ``'Z'``: Atom atomic numbers as an array of shape ``(n_atoms,)``. Arguments: - samples: List of sample dicts as :class:`TarSample`. If ``rho`` is present in the dict, the full-density-based model - is used in the simulation. Otherwise Lennard-Jones with Hartree electrostatics is used. + samples: List of sample dicts as :class:`TarSampleList`. File paths should be relative to ``base_path``. base_path: Path to the directory with the tar files. - n_proc: Number of parallel processes for loading data. The samples get divided evenly over the processes. + n_proc: Number of parallel processes for loading data. The sample lists get divided evenly over the processes. """ _timings = False - def __init__(self, samples: list[TarSample], base_path: PathLike = "./", n_proc: int = 1): + def __init__(self, samples: list[TarSampleList], base_path: PathLike = "./", n_proc: int = 1): self.samples = samples self.base_path = Path(base_path) self.n_proc = n_proc - def __len__(self): + def __len__(self) -> int: """Total number of samples (including rotations)""" return sum([len(s["rots"]) for s in self.samples]) @@ -170,91 +174,108 @@ def __iter__(self): def __next__(self): return next(self.iterator) - def _get_data(self, tar_path: PathLike, name: str): - tar_path = self.base_path / tar_path - with tarfile.open(tar_path, "r") as f: - data = np.load(f.extractfile(name)) - array = data["data"] - origin = data["origin"] - lattice = data["lattice"] - xyzs = data["xyz"] - Zs = data["Z"] - lvec = np.concatenate([origin[None, :], lattice], axis=0) + def _get_data(self, tar: tarfile.TarFile, name: str) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + data = np.load(tar.extractfile(name)) + array = data["data"] + origin = data["origin"] + lattice = data["lattice"] + xyzs = data["xyz"] + Zs = data["Z"] + lvec = np.concatenate([origin[None, :], lattice], axis=0) return array, lvec, xyzs, Zs - def _load_samples(self, samples: list[TarSample], i_proc: int, event: mp.Event): + def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: mp.Event): proc_id = str(time.time_ns() + 1000000000 * i_proc)[-10:] print(f"Starting worker {i_proc}, id {proc_id}") - for i, sample in enumerate(samples): - - if self._timings: - t0 = time.perf_counter() - - # Load data from tar(s) - rots = sample["rots"] - hartree_tar_path, name = sample["hartree"] - pot, lvec, xyzs, Zs = self._get_data(hartree_tar_path, name) - pot *= -1 # Unit conversion, eV -> V - if "rho" in sample: - rho_tar_path, name = sample["rho"] - rho, _, _, _ = self._get_data(rho_tar_path, name) - - if self._timings: - t1 = time.perf_counter() - - # Put the data to shared memory - sample_id_pot = f"{i_proc}_{proc_id}_{i}_pot" - shm_pot = mp.shared_memory.SharedMemory(create=True, size=pot.nbytes, name=sample_id_pot) - b = np.ndarray(pot.shape, dtype=np.float32, buffer=shm_pot.buf) - b[:] = pot[:] - - if "rho" in sample: - sample_id_rho = f"{i_proc}_{proc_id}_{i}__rho" - shm_rho = mp.shared_memory.SharedMemory(create=True, size=rho.nbytes, name=sample_id_rho) - b = np.ndarray(rho.shape, dtype=np.float32, buffer=shm_rho.buf) - b[:] = rho[:] - rho_shape = rho.shape - else: - sample_id_rho = None - rho_shape = None - - if self._timings: - t2 = time.perf_counter() - - # Inform the main process of the data using the queue - self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec, xyzs, Zs, rots)) - - if self._timings: - t3 = time.perf_counter() - - # Wait until main process is done with the data - if not event.wait(timeout=60): - raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.") - event.clear() - - if self._timings: - t4 = time.perf_counter() - - # Done with shared memory - shm_pot.unlink() - shm_pot.close() - if "rho" in sample: - shm_rho.unlink() - shm_rho.close() - - if self._timings: - t5 = time.perf_counter() - print( - f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Queue / Wait / Unlink: " - f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} / {t4 - t3:.5f} / {t5 - t4:.5f}" - ) + for sample_list in sample_lists: + + tar_path_hartree, name_list_hartree = sample_list["hartree"] + tar_hartree = tarfile.open(self.base_path / tar_path_hartree, "r") + + n_sample = len(name_list_hartree) + if len(sample_list["rots"]) != n_sample: + raise ValueError(f"Inconsistent number of rotations in sample list ({len(sample_list['rots'])} != {n_sample})") + + if "rho" in sample_list: + tar_path_rho, name_list_rho = sample_list["rho"] + tar_rho = tarfile.open(self.base_path / tar_path_rho, "r") + if len(name_list_rho) != n_sample: + raise ValueError( + f"Inconsistent number of samples between hartree and rho lists ({len(name_list_rho)} != {n_sample})" + ) + + for i_sample in range(n_sample): + + if self._timings: + t0 = time.perf_counter() + + # Load data from tar(s) + rots = sample_list["rots"][i_sample] + name_hartree = name_list_hartree[i_sample] + pot, lvec, xyzs, Zs = self._get_data(tar_hartree, name_hartree) + pot *= -1 # Unit conversion, eV -> V + if "rho" in sample_list: + name_rho = name_list_rho[i_sample] + rho, _, _, _ = self._get_data(tar_rho, name_rho) + + if self._timings: + t1 = time.perf_counter() + + # Put the data to shared memory + sample_id_pot = f"{i_proc}_{proc_id}_{i_sample}_pot" + shm_pot = mp.shared_memory.SharedMemory(create=True, size=pot.nbytes, name=sample_id_pot) + b = np.ndarray(pot.shape, dtype=np.float32, buffer=shm_pot.buf) + b[:] = pot[:] + + if "rho" in sample_list: + sample_id_rho = f"{i_proc}_{proc_id}_{i_sample}_rho" + shm_rho = mp.shared_memory.SharedMemory(create=True, size=rho.nbytes, name=sample_id_rho) + b = np.ndarray(rho.shape, dtype=np.float32, buffer=shm_rho.buf) + b[:] = rho[:] + rho_shape = rho.shape + else: + sample_id_rho = None + rho_shape = None + + if self._timings: + t2 = time.perf_counter() + + # Inform the main process of the data using the queue + self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec, xyzs, Zs, rots)) + + if self._timings: + t3 = time.perf_counter() + + # Wait until main process is done with the data + if not event.wait(timeout=60): + raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.") + event.clear() + + if self._timings: + t4 = time.perf_counter() + + # Done with shared memory + shm_pot.close() + shm_pot.unlink() + if "rho" in sample_list: + shm_rho.close() + shm_rho.unlink() + + if self._timings: + t5 = time.perf_counter() + print( + f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Queue / Wait / Unlink: " + f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} / {t4 - t3:.5f} / {t5 - t4:.5f}" + ) + + tar_hartree.close() + if "rho" in sample_list: + tar_rho.close() def _yield_samples(self): - from ppafm.ocl.field import ElectronDensity, HartreePotential - for _ in range(len(self)): if self._timings: