@@ -157,14 +157,22 @@ class TarDataGenerator:
157
157
n_proc: Number of parallel processes for loading data. The sample lists get divided evenly over the processes.
158
158
For memory usage, note that a maximum number of samples double the number of processes can be loaded into
159
159
memory at the same time.
160
+ scale_pot: The loaded Hartree potentials are scaled by this factor in order to correct the units. The yielded potential should
161
+ be in units of V. The default value of -1 works for potentials in units of eV.
162
+ scale_rho: The loaded electron densities are scaled by this factor in order to correct the units. The yielded density should
163
+ be in units of e/Å^3 with positive sign for the electron density.
160
164
"""
161
165
162
166
_timings = False
163
167
164
- def __init__ (self , samples : list [TarSampleList ], base_path : PathLike = "./" , n_proc : int = 1 ):
168
+ def __init__ (
169
+ self , samples : list [TarSampleList ], base_path : PathLike = "./" , n_proc : int = 1 , scale_pot : float = - 1 , scale_rho : float = 1
170
+ ):
165
171
self .samples = samples
166
172
self .base_path = Path (base_path )
167
173
self .n_proc = n_proc
174
+ self .scale_pot = scale_pot
175
+ self .scale_rho = scale_rho
168
176
self .pot = None
169
177
self .rho = None
170
178
@@ -239,10 +247,13 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
239
247
# Load data from tar(s)
240
248
rots = sample_list ["rots" ][i_sample ]
241
249
pot , lvec_pot , xyzs , Zs = self ._get_data (tar_hartree , name_list_hartree [i_sample ])
242
- pot *= - 1 # Unit conversion, eV -> V
250
+ if not np .allclose (self .scale_pot , 1 ):
251
+ pot *= self .scale_pot
243
252
total_bytes += pot .nbytes
244
253
if use_rho :
245
254
rho , lvec_rho , _ , _ = self ._get_data (tar_rho , name_list_rho [i_sample ])
255
+ if not np .allclose (self .scale_rho , 1 ):
256
+ rho *= self .scale_rho
246
257
rho_shape = rho .shape
247
258
total_bytes += rho .nbytes
248
259
else :
0 commit comments