Skip to content

Commit b2c69a8

Browse files
committed
Fixed electron density lattice vector in tar data generator.
1 parent cd7ac83 commit b2c69a8

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

mlspm/data_generation.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,8 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
210210
if len(sample_list["rots"]) != n_sample:
211211
raise ValueError(f"Inconsistent number of rotations in sample list ({len(sample_list['rots'])} != {n_sample})")
212212

213-
if "rho" in sample_list:
213+
use_rho = ("rho" in sample_list) and (sample_list["rho"] is not None)
214+
if use_rho:
214215
tar_path_rho, name_list_rho = sample_list["rho"]
215216
tar_rho = tarfile.open(self.base_path / tar_path_rho, "r")
216217
if len(name_list_rho) != n_sample:
@@ -226,11 +227,11 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
226227
# Load data from tar(s)
227228
rots = sample_list["rots"][i_sample]
228229
name_hartree = name_list_hartree[i_sample]
229-
pot, lvec, xyzs, Zs = self._get_data(tar_hartree, name_hartree)
230+
pot, lvec_pot, xyzs, Zs = self._get_data(tar_hartree, name_hartree)
230231
pot *= -1 # Unit conversion, eV -> V
231-
if "rho" in sample_list:
232+
if use_rho:
232233
name_rho = name_list_rho[i_sample]
233-
rho, _, _, _ = self._get_data(tar_rho, name_rho)
234+
rho, lvec_rho, _, _ = self._get_data(tar_rho, name_rho)
234235

235236
if self._timings:
236237
t1 = time.perf_counter()
@@ -241,7 +242,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
241242
b = np.ndarray(pot.shape, dtype=np.float32, buffer=shm_pot.buf)
242243
b[:] = pot[:]
243244

244-
if "rho" in sample_list:
245+
if use_rho:
245246
sample_id_rho = f"{i_proc}_{proc_id}_{i_sample}_rho"
246247
shm_rho = mp.shared_memory.SharedMemory(create=True, size=rho.nbytes, name=sample_id_rho)
247248
b = np.ndarray(rho.shape, dtype=np.float32, buffer=shm_rho.buf)
@@ -250,12 +251,13 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
250251
else:
251252
sample_id_rho = None
252253
rho_shape = None
254+
lvec_rho = None
253255

254256
if self._timings:
255257
t2 = time.perf_counter()
256258

257259
# Inform the main process of the data using the queue
258-
self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec, xyzs, Zs, rots))
260+
self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots))
259261

260262
if self._timings:
261263
t3 = time.perf_counter()
@@ -271,7 +273,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
271273
# Done with shared memory
272274
shm_pot.close()
273275
shm_pot.unlink()
274-
if "rho" in sample_list:
276+
if use_rho:
275277
shm_rho.close()
276278
shm_rho.unlink()
277279

@@ -283,7 +285,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
283285
)
284286

285287
tar_hartree.close()
286-
if "rho" in sample_list:
288+
if use_rho:
287289
tar_rho.close()
288290

289291
def _yield_samples(self):
@@ -294,16 +296,16 @@ def _yield_samples(self):
294296
t0 = time.perf_counter()
295297

296298
# Get data info from the queue
297-
i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec, xyzs, Zs, rots = self.q.get(timeout=200)
299+
i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots = self.q.get(timeout=200)
298300

299301
# Get data from the shared memory
300302
shm_pot = mp.shared_memory.SharedMemory(sample_id_pot)
301303
pot = np.ndarray(pot_shape, dtype=np.float32, buffer=shm_pot.buf)
302-
pot = HartreePotential(pot, lvec)
304+
pot = HartreePotential(pot, lvec_pot)
303305
if sample_id_rho is not None:
304306
shm_rho = mp.shared_memory.SharedMemory(sample_id_rho)
305307
rho = np.ndarray(rho_shape, dtype=np.float32, buffer=shm_rho.buf)
306-
rho = ElectronDensity(rho, lvec)
308+
rho = ElectronDensity(rho, lvec_rho)
307309
else:
308310
rho = None
309311

tests/test_data_generation.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,15 @@ def test_tar_data_generator():
9898
rots.append([rot])
9999
names.append(name)
100100

101-
sample_list = [
101+
sample_list_fdbm = [
102102
{
103103
"hartree": (tar_path_hartree, names),
104104
"rho": (tar_path_rho, names),
105105
"rots": rots,
106106
}
107107
]
108108

109-
generator = TarDataGenerator(sample_list, base_path='./', n_proc=1)
109+
generator = TarDataGenerator(sample_list_fdbm, base_path='./', n_proc=1)
110110

111111
for i_sample, sample in enumerate(generator):
112112
assert np.allclose(sample['xyzs'], xyzs[i_sample])
@@ -117,5 +117,23 @@ def test_tar_data_generator():
117117
assert np.allclose(sample['rho_sample'].array, rhos[i_sample])
118118
assert np.allclose(sample['rho_sample'].lvec, lvecs[i_sample])
119119

120+
121+
sample_list_hartree = [
122+
{
123+
"hartree": (tar_path_hartree, names),
124+
"rho": None,
125+
"rots": rots,
126+
}
127+
]
128+
129+
generator = TarDataGenerator(sample_list_hartree, base_path='./', n_proc=1)
130+
131+
for i_sample, sample in enumerate(generator):
132+
assert np.allclose(sample['xyzs'], xyzs[i_sample])
133+
assert np.allclose(sample['Zs'], Zs[i_sample])
134+
assert np.allclose(sample['rot'], rots[i_sample])
135+
assert np.allclose(sample['qs'].array, -hartrees[i_sample])
136+
assert np.allclose(sample['qs'].lvec, lvecs[i_sample])
137+
120138
tar_path_hartree.unlink()
121139
tar_path_rho.unlink()

0 commit comments

Comments
 (0)