Skip to content

Commit b0b5b2a

Browse files
committed
Added option for scaling factors for the loaded arrays in the tar data generator.
1 parent cb7d186 commit b0b5b2a

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

mlspm/data_generation.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,22 @@ class TarDataGenerator:
157157
n_proc: Number of parallel processes for loading data. The sample lists get divided evenly over the processes.
158158
For memory usage, note that a maximum number of samples double the number of processes can be loaded into
159159
memory at the same time.
160+
scale_pot: The loaded Hartree potentials are scaled by this factor in order to correct the units. The yielded potential should
161+
be in units of V. The default value of -1 works for potentials in units of eV.
162+
scale_rho: The loaded electron densities are scaled by this factor in order to correct the units. The yielded density should
163+
be in units of e/Å^3 with positive sign for the electron density.
160164
"""
161165

162166
_timings = False
163167

164-
def __init__(self, samples: list[TarSampleList], base_path: PathLike = "./", n_proc: int = 1):
168+
def __init__(
169+
self, samples: list[TarSampleList], base_path: PathLike = "./", n_proc: int = 1, scale_pot: float = -1, scale_rho: float = 1
170+
):
165171
self.samples = samples
166172
self.base_path = Path(base_path)
167173
self.n_proc = n_proc
174+
self.scale_pot = scale_pot
175+
self.scale_rho = scale_rho
168176
self.pot = None
169177
self.rho = None
170178

@@ -239,10 +247,13 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
239247
# Load data from tar(s)
240248
rots = sample_list["rots"][i_sample]
241249
pot, lvec_pot, xyzs, Zs = self._get_data(tar_hartree, name_list_hartree[i_sample])
242-
pot *= -1 # Unit conversion, eV -> V
250+
if not np.allclose(self.scale_pot, 1):
251+
pot *= self.scale_pot
243252
total_bytes += pot.nbytes
244253
if use_rho:
245254
rho, lvec_rho, _, _ = self._get_data(tar_rho, name_list_rho[i_sample])
255+
if not np.allclose(self.scale_rho, 1):
256+
rho *= self.scale_rho
246257
rho_shape = rho.shape
247258
total_bytes += rho.nbytes
248259
else:

0 commit comments

Comments
 (0)