From b2c69a850ef16a2925020d0587ce038a163cb6df Mon Sep 17 00:00:00 2001 From: NikoOinonen Date: Fri, 15 Mar 2024 15:28:58 +0200 Subject: [PATCH] Fixed electron density lattice vector in tar data generator. --- mlspm/data_generation.py | 24 +++++++++++++----------- tests/test_data_generation.py | 22 ++++++++++++++++++++-- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/mlspm/data_generation.py b/mlspm/data_generation.py index 77051c8..7161671 100644 --- a/mlspm/data_generation.py +++ b/mlspm/data_generation.py @@ -210,7 +210,8 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m if len(sample_list["rots"]) != n_sample: raise ValueError(f"Inconsistent number of rotations in sample list ({len(sample_list['rots'])} != {n_sample})") - if "rho" in sample_list: + use_rho = ("rho" in sample_list) and (sample_list["rho"] is not None) + if use_rho: tar_path_rho, name_list_rho = sample_list["rho"] tar_rho = tarfile.open(self.base_path / tar_path_rho, "r") if len(name_list_rho) != n_sample: @@ -226,11 +227,11 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m # Load data from tar(s) rots = sample_list["rots"][i_sample] name_hartree = name_list_hartree[i_sample] - pot, lvec, xyzs, Zs = self._get_data(tar_hartree, name_hartree) + pot, lvec_pot, xyzs, Zs = self._get_data(tar_hartree, name_hartree) pot *= -1 # Unit conversion, eV -> V - if "rho" in sample_list: + if use_rho: name_rho = name_list_rho[i_sample] - rho, _, _, _ = self._get_data(tar_rho, name_rho) + rho, lvec_rho, _, _ = self._get_data(tar_rho, name_rho) if self._timings: t1 = time.perf_counter() @@ -241,7 +242,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m b = np.ndarray(pot.shape, dtype=np.float32, buffer=shm_pot.buf) b[:] = pot[:] - if "rho" in sample_list: + if use_rho: sample_id_rho = f"{i_proc}_{proc_id}_{i_sample}_rho" shm_rho = mp.shared_memory.SharedMemory(create=True, size=rho.nbytes, name=sample_id_rho) b = np.ndarray(rho.shape, dtype=np.float32, buffer=shm_rho.buf) @@ -250,12 +251,13 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m else: sample_id_rho = None rho_shape = None + lvec_rho = None if self._timings: t2 = time.perf_counter() # Inform the main process of the data using the queue - self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec, xyzs, Zs, rots)) + self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots)) if self._timings: t3 = time.perf_counter() @@ -271,7 +273,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m # Done with shared memory shm_pot.close() shm_pot.unlink() - if "rho" in sample_list: + if use_rho: shm_rho.close() shm_rho.unlink() @@ -283,7 +285,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m ) tar_hartree.close() - if "rho" in sample_list: + if use_rho: tar_rho.close() def _yield_samples(self): @@ -294,16 +296,16 @@ def _yield_samples(self): t0 = time.perf_counter() # Get data info from the queue - i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec, xyzs, Zs, rots = self.q.get(timeout=200) + i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots = self.q.get(timeout=200) # Get data from the shared memory shm_pot = mp.shared_memory.SharedMemory(sample_id_pot) pot = np.ndarray(pot_shape, dtype=np.float32, buffer=shm_pot.buf) - pot = HartreePotential(pot, lvec) + pot = HartreePotential(pot, lvec_pot) if sample_id_rho is not None: shm_rho = mp.shared_memory.SharedMemory(sample_id_rho) rho = np.ndarray(rho_shape, dtype=np.float32, buffer=shm_rho.buf) - rho = ElectronDensity(rho, lvec) + rho = ElectronDensity(rho, lvec_rho) else: rho = None diff --git a/tests/test_data_generation.py b/tests/test_data_generation.py index 2bce31e..456cd28 100644 --- a/tests/test_data_generation.py +++ b/tests/test_data_generation.py @@ -98,7 +98,7 @@ def test_tar_data_generator(): rots.append([rot]) names.append(name) - sample_list = [ + sample_list_fdbm = [ { "hartree": (tar_path_hartree, names), "rho": (tar_path_rho, names), @@ -106,7 +106,7 @@ def test_tar_data_generator(): } ] - generator = TarDataGenerator(sample_list, base_path='./', n_proc=1) + generator = TarDataGenerator(sample_list_fdbm, base_path='./', n_proc=1) for i_sample, sample in enumerate(generator): assert np.allclose(sample['xyzs'], xyzs[i_sample]) @@ -117,5 +117,23 @@ def test_tar_data_generator(): assert np.allclose(sample['rho_sample'].array, rhos[i_sample]) assert np.allclose(sample['rho_sample'].lvec, lvecs[i_sample]) + + sample_list_hartree = [ + { + "hartree": (tar_path_hartree, names), + "rho": None, + "rots": rots, + } + ] + + generator = TarDataGenerator(sample_list_hartree, base_path='./', n_proc=1) + + for i_sample, sample in enumerate(generator): + assert np.allclose(sample['xyzs'], xyzs[i_sample]) + assert np.allclose(sample['Zs'], Zs[i_sample]) + assert np.allclose(sample['rot'], rots[i_sample]) + assert np.allclose(sample['qs'].array, -hartrees[i_sample]) + assert np.allclose(sample['qs'].lvec, lvecs[i_sample]) + tar_path_hartree.unlink() tar_path_rho.unlink()