11import io
22import multiprocessing as mp
3- import multiprocessing .shared_memory
43import os
54import tarfile
65import time
6+ from multiprocessing .shared_memory import SharedMemory
77from os import PathLike
88from pathlib import Path
99from typing import Optional , TypedDict
@@ -297,7 +297,9 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
297297 f"Average load time: { dt / n_sample_total } s."
298298 )
299299
300- def _get_queue_sample (self ):
300+ def _get_queue_sample (
301+ self ,
302+ ) -> tuple [int , np .ndarray , np .ndarray , list [np .ndarray ], HartreePotential , SharedMemory , ElectronDensity , SharedMemory , str ]:
301303
302304 if self ._timings :
303305 t0 = time .perf_counter ()
@@ -307,7 +309,7 @@ def _get_queue_sample(self):
307309 if self ._timings :
308310 t1 = time .perf_counter ()
309311
310- shm_pot = mp . shared_memory . SharedMemory (sample_id_pot )
312+ shm_pot = SharedMemory (sample_id_pot )
311313 pot = np .ndarray (pot_shape , dtype = np .float32 , buffer = shm_pot .buf )
312314 if self .pot is None :
313315 self .pot = HartreePotential (pot , lvec_pot )
@@ -318,12 +320,12 @@ def _get_queue_sample(self):
318320 t2 = time .perf_counter ()
319321
320322 if sample_id_rho is not None :
321- shm_rho = mp . shared_memory . SharedMemory (sample_id_rho )
323+ shm_rho = SharedMemory (sample_id_rho )
322324 rho = np .ndarray (rho_shape , dtype = np .float32 , buffer = shm_rho .buf )
323325 if self .rho is None :
324326 self .rho = ElectronDensity (rho , lvec_rho )
325327 else :
326- self .rho .update_array (pot , lvec_pot )
328+ self .rho .update_array (rho , lvec_rho )
327329 else :
328330 shm_rho = None
329331 rho = None
@@ -374,7 +376,7 @@ def _yield_samples(self):
374376
375377
376378def _put_to_shared_memory (array , name ):
377- shm = mp . shared_memory . SharedMemory (create = True , size = array .nbytes , name = name )
379+ shm = SharedMemory (create = True , size = array .nbytes , name = name )
378380 b = np .ndarray (array .shape , dtype = np .float32 , buffer = shm .buf )
379381 b [:] = array [:]
380382 return shm
0 commit comments