Skip to content

Commit

Permalink
Fixed electron density lattice vector in tar data generator.
Browse files Browse the repository at this point in the history
  • Loading branch information
NikoOinonen committed Mar 15, 2024
1 parent cd7ac83 commit b2c69a8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 13 deletions.
24 changes: 13 additions & 11 deletions mlspm/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
if len(sample_list["rots"]) != n_sample:
raise ValueError(f"Inconsistent number of rotations in sample list ({len(sample_list['rots'])} != {n_sample})")

if "rho" in sample_list:
use_rho = ("rho" in sample_list) and (sample_list["rho"] is not None)
if use_rho:
tar_path_rho, name_list_rho = sample_list["rho"]
tar_rho = tarfile.open(self.base_path / tar_path_rho, "r")
if len(name_list_rho) != n_sample:
Expand All @@ -226,11 +227,11 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
# Load data from tar(s)
rots = sample_list["rots"][i_sample]
name_hartree = name_list_hartree[i_sample]
pot, lvec, xyzs, Zs = self._get_data(tar_hartree, name_hartree)
pot, lvec_pot, xyzs, Zs = self._get_data(tar_hartree, name_hartree)
pot *= -1 # Unit conversion, eV -> V
if "rho" in sample_list:
if use_rho:
name_rho = name_list_rho[i_sample]
rho, _, _, _ = self._get_data(tar_rho, name_rho)
rho, lvec_rho, _, _ = self._get_data(tar_rho, name_rho)

if self._timings:
t1 = time.perf_counter()
Expand All @@ -241,7 +242,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
b = np.ndarray(pot.shape, dtype=np.float32, buffer=shm_pot.buf)
b[:] = pot[:]

if "rho" in sample_list:
if use_rho:
sample_id_rho = f"{i_proc}_{proc_id}_{i_sample}_rho"
shm_rho = mp.shared_memory.SharedMemory(create=True, size=rho.nbytes, name=sample_id_rho)
b = np.ndarray(rho.shape, dtype=np.float32, buffer=shm_rho.buf)
Expand All @@ -250,12 +251,13 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
else:
sample_id_rho = None
rho_shape = None
lvec_rho = None

if self._timings:
t2 = time.perf_counter()

# Inform the main process of the data using the queue
self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec, xyzs, Zs, rots))
self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots))

if self._timings:
t3 = time.perf_counter()
Expand All @@ -271,7 +273,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
# Done with shared memory
shm_pot.close()
shm_pot.unlink()
if "rho" in sample_list:
if use_rho:
shm_rho.close()
shm_rho.unlink()

Expand All @@ -283,7 +285,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
)

tar_hartree.close()
if "rho" in sample_list:
if use_rho:
tar_rho.close()

def _yield_samples(self):
Expand All @@ -294,16 +296,16 @@ def _yield_samples(self):
t0 = time.perf_counter()

# Get data info from the queue
i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec, xyzs, Zs, rots = self.q.get(timeout=200)
i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots = self.q.get(timeout=200)

# Get data from the shared memory
shm_pot = mp.shared_memory.SharedMemory(sample_id_pot)
pot = np.ndarray(pot_shape, dtype=np.float32, buffer=shm_pot.buf)
pot = HartreePotential(pot, lvec)
pot = HartreePotential(pot, lvec_pot)
if sample_id_rho is not None:
shm_rho = mp.shared_memory.SharedMemory(sample_id_rho)
rho = np.ndarray(rho_shape, dtype=np.float32, buffer=shm_rho.buf)
rho = ElectronDensity(rho, lvec)
rho = ElectronDensity(rho, lvec_rho)
else:
rho = None

Expand Down
22 changes: 20 additions & 2 deletions tests/test_data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,15 @@ def test_tar_data_generator():
rots.append([rot])
names.append(name)

sample_list = [
sample_list_fdbm = [
{
"hartree": (tar_path_hartree, names),
"rho": (tar_path_rho, names),
"rots": rots,
}
]

generator = TarDataGenerator(sample_list, base_path='./', n_proc=1)
generator = TarDataGenerator(sample_list_fdbm, base_path='./', n_proc=1)

for i_sample, sample in enumerate(generator):
assert np.allclose(sample['xyzs'], xyzs[i_sample])
Expand All @@ -117,5 +117,23 @@ def test_tar_data_generator():
assert np.allclose(sample['rho_sample'].array, rhos[i_sample])
assert np.allclose(sample['rho_sample'].lvec, lvecs[i_sample])


sample_list_hartree = [
{
"hartree": (tar_path_hartree, names),
"rho": None,
"rots": rots,
}
]

generator = TarDataGenerator(sample_list_hartree, base_path='./', n_proc=1)

for i_sample, sample in enumerate(generator):
assert np.allclose(sample['xyzs'], xyzs[i_sample])
assert np.allclose(sample['Zs'], Zs[i_sample])
assert np.allclose(sample['rot'], rots[i_sample])
assert np.allclose(sample['qs'].array, -hartrees[i_sample])
assert np.allclose(sample['qs'].lvec, lvecs[i_sample])

tar_path_hartree.unlink()
tar_path_rho.unlink()

0 comments on commit b2c69a8

Please sign in to comment.