@@ -157,14 +157,22 @@ class TarDataGenerator:
157157 n_proc: Number of parallel processes for loading data. The sample lists get divided evenly over the processes.
158158 For memory usage, note that a maximum number of samples double the number of processes can be loaded into
159159 memory at the same time.
160+ scale_pot: The loaded Hartree potentials are scaled by this factor in order to correct the units. The yielded potential should
161+ be in units of V. The default value of -1 works for potentials in units of eV.
162+ scale_rho: The loaded electron densities are scaled by this factor in order to correct the units. The yielded density should
163+ be in units of e/Å^3 with positive sign for the electron density.
160164 """
161165
162166 _timings = False
163167
164- def __init__ (self , samples : list [TarSampleList ], base_path : PathLike = "./" , n_proc : int = 1 ):
168+ def __init__ (
169+ self , samples : list [TarSampleList ], base_path : PathLike = "./" , n_proc : int = 1 , scale_pot : float = - 1 , scale_rho : float = 1
170+ ):
165171 self .samples = samples
166172 self .base_path = Path (base_path )
167173 self .n_proc = n_proc
174+ self .scale_pot = scale_pot
175+ self .scale_rho = scale_rho
168176 self .pot = None
169177 self .rho = None
170178
@@ -239,10 +247,13 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
239247 # Load data from tar(s)
240248 rots = sample_list ["rots" ][i_sample ]
241249 pot , lvec_pot , xyzs , Zs = self ._get_data (tar_hartree , name_list_hartree [i_sample ])
242- pot *= - 1 # Unit conversion, eV -> V
250+ if not np .allclose (self .scale_pot , 1 ):
251+ pot *= self .scale_pot
243252 total_bytes += pot .nbytes
244253 if use_rho :
245254 rho , lvec_rho , _ , _ = self ._get_data (tar_rho , name_list_rho [i_sample ])
255+ if not np .allclose (self .scale_rho , 1 ):
256+ rho *= self .scale_rho
246257 rho_shape = rho .shape
247258 total_bytes += rho .nbytes
248259 else :
0 commit comments