1
1
import io
2
2
import multiprocessing as mp
3
- import multiprocessing .shared_memory
4
3
import os
5
4
import tarfile
6
5
import time
6
+ from multiprocessing .shared_memory import SharedMemory
7
7
from os import PathLike
8
8
from pathlib import Path
9
9
from typing import Optional , TypedDict
@@ -297,7 +297,9 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
297
297
f"Average load time: { dt / n_sample_total } s."
298
298
)
299
299
300
- def _get_queue_sample (self ):
300
+ def _get_queue_sample (
301
+ self ,
302
+ ) -> tuple [int , np .ndarray , np .ndarray , list [np .ndarray ], HartreePotential , SharedMemory , ElectronDensity , SharedMemory , str ]:
301
303
302
304
if self ._timings :
303
305
t0 = time .perf_counter ()
@@ -307,7 +309,7 @@ def _get_queue_sample(self):
307
309
if self ._timings :
308
310
t1 = time .perf_counter ()
309
311
310
- shm_pot = mp . shared_memory . SharedMemory (sample_id_pot )
312
+ shm_pot = SharedMemory (sample_id_pot )
311
313
pot = np .ndarray (pot_shape , dtype = np .float32 , buffer = shm_pot .buf )
312
314
if self .pot is None :
313
315
self .pot = HartreePotential (pot , lvec_pot )
@@ -318,12 +320,12 @@ def _get_queue_sample(self):
318
320
t2 = time .perf_counter ()
319
321
320
322
if sample_id_rho is not None :
321
- shm_rho = mp . shared_memory . SharedMemory (sample_id_rho )
323
+ shm_rho = SharedMemory (sample_id_rho )
322
324
rho = np .ndarray (rho_shape , dtype = np .float32 , buffer = shm_rho .buf )
323
325
if self .rho is None :
324
326
self .rho = ElectronDensity (rho , lvec_rho )
325
327
else :
326
- self .rho .update_array (pot , lvec_pot )
328
+ self .rho .update_array (rho , lvec_rho )
327
329
else :
328
330
shm_rho = None
329
331
rho = None
@@ -374,7 +376,7 @@ def _yield_samples(self):
374
376
375
377
376
378
def _put_to_shared_memory (array , name ):
377
- shm = mp . shared_memory . SharedMemory (create = True , size = array .nbytes , name = name )
379
+ shm = SharedMemory (create = True , size = array .nbytes , name = name )
378
380
b = np .ndarray (array .shape , dtype = np .float32 , buffer = shm .buf )
379
381
b [:] = array [:]
380
382
return shm
0 commit comments