@@ -210,7 +210,8 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
210
210
if len (sample_list ["rots" ]) != n_sample :
211
211
raise ValueError (f"Inconsistent number of rotations in sample list ({ len (sample_list ['rots' ])} != { n_sample } )" )
212
212
213
- if "rho" in sample_list :
213
+ use_rho = ("rho" in sample_list ) and (sample_list ["rho" ] is not None )
214
+ if use_rho :
214
215
tar_path_rho , name_list_rho = sample_list ["rho" ]
215
216
tar_rho = tarfile .open (self .base_path / tar_path_rho , "r" )
216
217
if len (name_list_rho ) != n_sample :
@@ -226,11 +227,11 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
226
227
# Load data from tar(s)
227
228
rots = sample_list ["rots" ][i_sample ]
228
229
name_hartree = name_list_hartree [i_sample ]
229
- pot , lvec , xyzs , Zs = self ._get_data (tar_hartree , name_hartree )
230
+ pot , lvec_pot , xyzs , Zs = self ._get_data (tar_hartree , name_hartree )
230
231
pot *= - 1 # Unit conversion, eV -> V
231
- if "rho" in sample_list :
232
+ if use_rho :
232
233
name_rho = name_list_rho [i_sample ]
233
- rho , _ , _ , _ = self ._get_data (tar_rho , name_rho )
234
+ rho , lvec_rho , _ , _ = self ._get_data (tar_rho , name_rho )
234
235
235
236
if self ._timings :
236
237
t1 = time .perf_counter ()
@@ -241,7 +242,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
241
242
b = np .ndarray (pot .shape , dtype = np .float32 , buffer = shm_pot .buf )
242
243
b [:] = pot [:]
243
244
244
- if "rho" in sample_list :
245
+ if use_rho :
245
246
sample_id_rho = f"{ i_proc } _{ proc_id } _{ i_sample } _rho"
246
247
shm_rho = mp .shared_memory .SharedMemory (create = True , size = rho .nbytes , name = sample_id_rho )
247
248
b = np .ndarray (rho .shape , dtype = np .float32 , buffer = shm_rho .buf )
@@ -250,12 +251,13 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
250
251
else :
251
252
sample_id_rho = None
252
253
rho_shape = None
254
+ lvec_rho = None
253
255
254
256
if self ._timings :
255
257
t2 = time .perf_counter ()
256
258
257
259
# Inform the main process of the data using the queue
258
- self .q .put ((i_proc , sample_id_pot , sample_id_rho , pot .shape , rho_shape , lvec , xyzs , Zs , rots ))
260
+ self .q .put ((i_proc , sample_id_pot , sample_id_rho , pot .shape , rho_shape , lvec_pot , lvec_rho , xyzs , Zs , rots ))
259
261
260
262
if self ._timings :
261
263
t3 = time .perf_counter ()
@@ -271,7 +273,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
271
273
# Done with shared memory
272
274
shm_pot .close ()
273
275
shm_pot .unlink ()
274
- if "rho" in sample_list :
276
+ if use_rho :
275
277
shm_rho .close ()
276
278
shm_rho .unlink ()
277
279
@@ -283,7 +285,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
283
285
)
284
286
285
287
tar_hartree .close ()
286
- if "rho" in sample_list :
288
+ if use_rho :
287
289
tar_rho .close ()
288
290
289
291
def _yield_samples (self ):
@@ -294,16 +296,16 @@ def _yield_samples(self):
294
296
t0 = time .perf_counter ()
295
297
296
298
# Get data info from the queue
297
- i_proc , sample_id_pot , sample_id_rho , pot_shape , rho_shape , lvec , xyzs , Zs , rots = self .q .get (timeout = 200 )
299
+ i_proc , sample_id_pot , sample_id_rho , pot_shape , rho_shape , lvec_pot , lvec_rho , xyzs , Zs , rots = self .q .get (timeout = 200 )
298
300
299
301
# Get data from the shared memory
300
302
shm_pot = mp .shared_memory .SharedMemory (sample_id_pot )
301
303
pot = np .ndarray (pot_shape , dtype = np .float32 , buffer = shm_pot .buf )
302
- pot = HartreePotential (pot , lvec )
304
+ pot = HartreePotential (pot , lvec_pot )
303
305
if sample_id_rho is not None :
304
306
shm_rho = mp .shared_memory .SharedMemory (sample_id_rho )
305
307
rho = np .ndarray (rho_shape , dtype = np .float32 , buffer = shm_rho .buf )
306
- rho = ElectronDensity (rho , lvec )
308
+ rho = ElectronDensity (rho , lvec_rho )
307
309
else :
308
310
rho = None
309
311
0 commit comments