Skip to content

Commit cb7d186

Browse files
committed
Correct update of rho array in tar generator.
1 parent 9af2380 commit cb7d186

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

mlspm/data_generation.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import io
22
import multiprocessing as mp
3-
import multiprocessing.shared_memory
43
import os
54
import tarfile
65
import time
6+
from multiprocessing.shared_memory import SharedMemory
77
from os import PathLike
88
from pathlib import Path
99
from typing import Optional, TypedDict
@@ -297,7 +297,9 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
297297
f"Average load time: {dt / n_sample_total}s."
298298
)
299299

300-
def _get_queue_sample(self):
300+
def _get_queue_sample(
301+
self,
302+
) -> tuple[int, np.ndarray, np.ndarray, list[np.ndarray], HartreePotential, SharedMemory, ElectronDensity, SharedMemory, str]:
301303

302304
if self._timings:
303305
t0 = time.perf_counter()
@@ -307,7 +309,7 @@ def _get_queue_sample(self):
307309
if self._timings:
308310
t1 = time.perf_counter()
309311

310-
shm_pot = mp.shared_memory.SharedMemory(sample_id_pot)
312+
shm_pot = SharedMemory(sample_id_pot)
311313
pot = np.ndarray(pot_shape, dtype=np.float32, buffer=shm_pot.buf)
312314
if self.pot is None:
313315
self.pot = HartreePotential(pot, lvec_pot)
@@ -318,12 +320,12 @@ def _get_queue_sample(self):
318320
t2 = time.perf_counter()
319321

320322
if sample_id_rho is not None:
321-
shm_rho = mp.shared_memory.SharedMemory(sample_id_rho)
323+
shm_rho = SharedMemory(sample_id_rho)
322324
rho = np.ndarray(rho_shape, dtype=np.float32, buffer=shm_rho.buf)
323325
if self.rho is None:
324326
self.rho = ElectronDensity(rho, lvec_rho)
325327
else:
326-
self.rho.update_array(pot, lvec_pot)
328+
self.rho.update_array(rho, lvec_rho)
327329
else:
328330
shm_rho = None
329331
rho = None
@@ -374,7 +376,7 @@ def _yield_samples(self):
374376

375377

376378
def _put_to_shared_memory(array, name):
377-
shm = mp.shared_memory.SharedMemory(create=True, size=array.nbytes, name=name)
379+
shm = SharedMemory(create=True, size=array.nbytes, name=name)
378380
b = np.ndarray(array.shape, dtype=np.float32, buffer=shm.buf)
379381
b[:] = array[:]
380382
return shm

0 commit comments

Comments
 (0)