Skip to content

Commit

Permalink
Changed TarDataGenerator sample format.
Browse files Browse the repository at this point in the history
The samples are now divided by the tar file so that the access pattern is sequential.
  • Loading branch information
NikoOinonen committed Mar 11, 2024
1 parent 22bfd3b commit 4f97960
Showing 1 changed file with 110 additions and 89 deletions.
199 changes: 110 additions & 89 deletions mlspm/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import numpy as np
from PIL import Image
from ppafm.ocl.field import ElectronDensity, HartreePotential


class TarWriter:
"""
Expand Down Expand Up @@ -109,18 +111,21 @@ def get_tarinfo(fname: str, file_bytes: io.BytesIO):
info.mtime = time.time()
return info

class TarSample(TypedDict, total=False):

class TarSampleList(TypedDict, total=False):
"""
- ``'hartree'``: Path to the Hartree potential. First item in the tuple is the path
to the tar file relative to ``base_path``, and second entry is the tar file member name.
- ``'rho'``: (Optional) Path to the electron density. First item in the tuple is the path
to the tar file relative to ``base_path``, and second entry is the tar file member name.
- ``'rots'``: List of rotations to generate for the sample.
- ``'hartree'``: Paths to the Hartree potentials. First item in the tuple is the path to the tar file,
and second entry is a list of tar file member names.
- ``'rho'``: (Optional) Paths to the electron densities. First item in the tuple is the path to the tar
file, and second entry is a list of tar file member names.
- ``'rots'``: List of rotations for each sample.
"""
hartree: tuple[str, str]
rho: tuple[str, str]

hartree: tuple[PathLike, list[str]]
rho: tuple[PathLike, list[str]]
rots: list[np.ndarray]


class TarDataGenerator:
"""
Iterable that loads data from tar archives with data saved in npz format for generating samples
Expand All @@ -135,20 +140,19 @@ class TarDataGenerator:
- ``'Z'``: Atom atomic numbers as an array of shape ``(n_atoms,)``.
Arguments:
samples: List of sample dicts as :class:`TarSample`. If ``rho`` is present in the dict, the full-density-based model
is used in the simulation. Otherwise Lennard-Jones with Hartree electrostatics is used.
samples: List of sample dicts as :class:`TarSampleList`. File paths should be relative to ``base_path``.
base_path: Path to the directory with the tar files.
n_proc: Number of parallel processes for loading data. The samples get divided evenly over the processes.
n_proc: Number of parallel processes for loading data. The sample lists get divided evenly over the processes.
"""

_timings = False

def __init__(self, samples: list[TarSample], base_path: PathLike = "./", n_proc: int = 1):
def __init__(self, samples: list[TarSampleList], base_path: PathLike = "./", n_proc: int = 1):
self.samples = samples
self.base_path = Path(base_path)
self.n_proc = n_proc

def __len__(self):
def __len__(self) -> int:
"""Total number of samples (including rotations)"""
return sum([len(s["rots"]) for s in self.samples])

Expand All @@ -170,91 +174,108 @@ def __iter__(self):
def __next__(self):
return next(self.iterator)

def _get_data(self, tar_path: PathLike, name: str):
tar_path = self.base_path / tar_path
with tarfile.open(tar_path, "r") as f:
data = np.load(f.extractfile(name))
array = data["data"]
origin = data["origin"]
lattice = data["lattice"]
xyzs = data["xyz"]
Zs = data["Z"]
lvec = np.concatenate([origin[None, :], lattice], axis=0)
def _get_data(self, tar: tarfile.TarFile, name: str) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
data = np.load(tar.extractfile(name))
array = data["data"]
origin = data["origin"]
lattice = data["lattice"]
xyzs = data["xyz"]
Zs = data["Z"]
lvec = np.concatenate([origin[None, :], lattice], axis=0)
return array, lvec, xyzs, Zs

def _load_samples(self, samples: list[TarSample], i_proc: int, event: mp.Event):
def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: mp.Event):

proc_id = str(time.time_ns() + 1000000000 * i_proc)[-10:]
print(f"Starting worker {i_proc}, id {proc_id}")

for i, sample in enumerate(samples):

if self._timings:
t0 = time.perf_counter()

# Load data from tar(s)
rots = sample["rots"]
hartree_tar_path, name = sample["hartree"]
pot, lvec, xyzs, Zs = self._get_data(hartree_tar_path, name)
pot *= -1 # Unit conversion, eV -> V
if "rho" in sample:
rho_tar_path, name = sample["rho"]
rho, _, _, _ = self._get_data(rho_tar_path, name)

if self._timings:
t1 = time.perf_counter()

# Put the data to shared memory
sample_id_pot = f"{i_proc}_{proc_id}_{i}_pot"
shm_pot = mp.shared_memory.SharedMemory(create=True, size=pot.nbytes, name=sample_id_pot)
b = np.ndarray(pot.shape, dtype=np.float32, buffer=shm_pot.buf)
b[:] = pot[:]

if "rho" in sample:
sample_id_rho = f"{i_proc}_{proc_id}_{i}__rho"
shm_rho = mp.shared_memory.SharedMemory(create=True, size=rho.nbytes, name=sample_id_rho)
b = np.ndarray(rho.shape, dtype=np.float32, buffer=shm_rho.buf)
b[:] = rho[:]
rho_shape = rho.shape
else:
sample_id_rho = None
rho_shape = None

if self._timings:
t2 = time.perf_counter()

# Inform the main process of the data using the queue
self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec, xyzs, Zs, rots))

if self._timings:
t3 = time.perf_counter()

# Wait until main process is done with the data
if not event.wait(timeout=60):
raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.")
event.clear()

if self._timings:
t4 = time.perf_counter()

# Done with shared memory
shm_pot.unlink()
shm_pot.close()
if "rho" in sample:
shm_rho.unlink()
shm_rho.close()

if self._timings:
t5 = time.perf_counter()
print(
f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Queue / Wait / Unlink: "
f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} / {t4 - t3:.5f} / {t5 - t4:.5f}"
)
for sample_list in sample_lists:

tar_path_hartree, name_list_hartree = sample_list["hartree"]
tar_hartree = tarfile.open(self.base_path / tar_path_hartree, "r")

n_sample = len(name_list_hartree)
if len(sample_list["rots"]) != n_sample:
raise ValueError(f"Inconsistent number of rotations in sample list ({len(sample_list['rots'])} != {n_sample})")

if "rho" in sample_list:
tar_path_rho, name_list_rho = sample_list["rho"]
tar_rho = tarfile.open(self.base_path / tar_path_rho, "r")
if len(name_list_rho) != n_sample:
raise ValueError(
f"Inconsistent number of samples between hartree and rho lists ({len(name_list_rho)} != {n_sample})"
)

for i_sample in range(n_sample):

if self._timings:
t0 = time.perf_counter()

# Load data from tar(s)
rots = sample_list["rots"][i_sample]
name_hartree = name_list_hartree[i_sample]
pot, lvec, xyzs, Zs = self._get_data(tar_hartree, name_hartree)
pot *= -1 # Unit conversion, eV -> V
if "rho" in sample_list:
name_rho = name_list_rho[i_sample]
rho, _, _, _ = self._get_data(tar_rho, name_rho)

if self._timings:
t1 = time.perf_counter()

# Put the data to shared memory
sample_id_pot = f"{i_proc}_{proc_id}_{i_sample}_pot"
shm_pot = mp.shared_memory.SharedMemory(create=True, size=pot.nbytes, name=sample_id_pot)
b = np.ndarray(pot.shape, dtype=np.float32, buffer=shm_pot.buf)
b[:] = pot[:]

if "rho" in sample_list:
sample_id_rho = f"{i_proc}_{proc_id}_{i_sample}_rho"
shm_rho = mp.shared_memory.SharedMemory(create=True, size=rho.nbytes, name=sample_id_rho)
b = np.ndarray(rho.shape, dtype=np.float32, buffer=shm_rho.buf)
b[:] = rho[:]
rho_shape = rho.shape
else:
sample_id_rho = None
rho_shape = None

if self._timings:
t2 = time.perf_counter()

# Inform the main process of the data using the queue
self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec, xyzs, Zs, rots))

if self._timings:
t3 = time.perf_counter()

# Wait until main process is done with the data
if not event.wait(timeout=60):
raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.")
event.clear()

if self._timings:
t4 = time.perf_counter()

# Done with shared memory
shm_pot.close()
shm_pot.unlink()
if "rho" in sample_list:
shm_rho.close()
shm_rho.unlink()

if self._timings:
t5 = time.perf_counter()
print(
f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Queue / Wait / Unlink: "
f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} / {t4 - t3:.5f} / {t5 - t4:.5f}"
)

tar_hartree.close()
if "rho" in sample_list:
tar_rho.close()

def _yield_samples(self):

from ppafm.ocl.field import ElectronDensity, HartreePotential

for _ in range(len(self)):

if self._timings:
Expand Down

0 comments on commit 4f97960

Please sign in to comment.