22
22
23
23
from collections import namedtuple
24
24
from collections .abc import Sequence
25
+ from typing import cast
25
26
26
27
import cloudpickle
27
28
import numpy as np
31
32
from rich .theme import Theme
32
33
from threadpoolctl import threadpool_limits
33
34
35
+ from pymc .backends .zarr import ZarrChain
34
36
from pymc .blocking import DictToArrayBijection
35
37
from pymc .exceptions import SamplingError
36
38
from pymc .util import (
@@ -104,13 +106,25 @@ def __init__(
104
106
tune : int ,
105
107
rng_state : RandomGeneratorState ,
106
108
blas_cores ,
109
+ chain : int ,
110
+ zarr_chains : list [ZarrChain ] | bytes | None = None ,
111
+ zarr_chains_is_pickled : bool = False ,
107
112
):
108
113
# For some strange reason, spawn multiprocessing doesn't copy the rng
109
114
# seed sequence, so we have to rebuild it from scratch
110
115
rng = random_generator_from_state (rng_state )
111
116
self ._msg_pipe = msg_pipe
112
117
self ._step_method = step_method
113
118
self ._step_method_is_pickled = step_method_is_pickled
119
+ self .chain = chain
120
+ self ._zarr_recording = False
121
+ self ._zarr_chain : ZarrChain | None = None
122
+ if zarr_chains_is_pickled :
123
+ self ._zarr_chain = cloudpickle .loads (zarr_chains )[self .chain ]
124
+ elif zarr_chains is not None :
125
+ self ._zarr_chain = cast (list [ZarrChain ], zarr_chains )[self .chain ]
126
+ self ._zarr_recording = self ._zarr_chain is not None
127
+
114
128
self ._shared_point = shared_point
115
129
self ._rng = rng
116
130
self ._draws = draws
@@ -135,6 +149,7 @@ def run(self):
135
149
# We do not create this in __init__, as pickling this
136
150
# would destroy the shared memory.
137
151
self ._unpickle_step_method ()
152
+ self ._link_step_to_zarrchain ()
138
153
self ._point = self ._make_numpy_refs ()
139
154
self ._start_loop ()
140
155
except KeyboardInterrupt :
@@ -148,6 +163,10 @@ def run(self):
148
163
finally :
149
164
self ._msg_pipe .close ()
150
165
166
+ def _link_step_to_zarrchain (self ):
167
+ if self ._zarr_recording :
168
+ self ._zarr_chain .link_stepper (self ._step_method )
169
+
151
170
def _wait_for_abortion (self ):
152
171
while True :
153
172
msg = self ._recv_msg ()
@@ -170,6 +189,7 @@ def _recv_msg(self):
170
189
return self ._msg_pipe .recv ()
171
190
172
191
def _start_loop (self ):
192
+ zarr_recording = self ._zarr_recording
173
193
self ._step_method .set_rng (self ._rng )
174
194
175
195
draw = 0
@@ -199,6 +219,8 @@ def _start_loop(self):
199
219
if msg [0 ] == "abort" :
200
220
raise KeyboardInterrupt ()
201
221
elif msg [0 ] == "write_next" :
222
+ if zarr_recording :
223
+ self ._zarr_chain .record (point , stats )
202
224
self ._write_point (point )
203
225
is_last = draw + 1 == self ._draws + self ._tune
204
226
self ._msg_pipe .send (("writing_done" , is_last , draw , tuning , stats ))
@@ -225,6 +247,8 @@ def __init__(
225
247
start : dict [str , np .ndarray ],
226
248
blas_cores ,
227
249
mp_ctx ,
250
+ zarr_chains : list [ZarrChain ] | None = None ,
251
+ zarr_chains_pickled : bytes | None = None ,
228
252
):
229
253
self .chain = chain
230
254
process_name = f"worker_chain_{ chain } "
@@ -247,6 +271,16 @@ def __init__(
247
271
self ._readable = True
248
272
self ._num_samples = 0
249
273
274
+ zarr_chains_send : list [ZarrChain ] | bytes | None = None
275
+ if zarr_chains_pickled is not None :
276
+ zarr_chains_send = zarr_chains_pickled
277
+ elif zarr_chains is not None :
278
+ if mp_ctx .get_start_method () == "spawn" :
279
+ raise ValueError (
280
+ "please provide a pre-pickled zarr_chains when multiprocessing start method is 'spawn'"
281
+ )
282
+ zarr_chains_send = zarr_chains
283
+
250
284
if step_method_pickled is not None :
251
285
step_method_send = step_method_pickled
252
286
else :
@@ -270,6 +304,9 @@ def __init__(
270
304
tune ,
271
305
get_state_from_generator (rng ),
272
306
blas_cores ,
307
+ self .chain ,
308
+ zarr_chains_send ,
309
+ zarr_chains_pickled is not None ,
273
310
),
274
311
)
275
312
self ._process .start ()
@@ -392,6 +429,7 @@ def __init__(
392
429
progressbar_theme : Theme | None = default_progress_theme ,
393
430
blas_cores : int | None = None ,
394
431
mp_ctx = None ,
432
+ zarr_chains : list [ZarrChain ] | None = None ,
395
433
):
396
434
if any (len (arg ) != chains for arg in [rngs , start_points ]):
397
435
raise ValueError (f"Number of rngs and start_points must be { chains } ." )
@@ -412,8 +450,15 @@ def __init__(
412
450
mp_ctx = multiprocessing .get_context (mp_ctx )
413
451
414
452
step_method_pickled = None
453
+ zarr_chains_pickled = None
454
+ self .zarr_recording = False
455
+ if zarr_chains is not None :
456
+ assert all (isinstance (zarr_chain , ZarrChain ) for zarr_chain in zarr_chains )
457
+ self .zarr_recording = True
415
458
if mp_ctx .get_start_method () != "fork" :
416
459
step_method_pickled = cloudpickle .dumps (step_method , protocol = - 1 )
460
+ if zarr_chains is not None :
461
+ zarr_chains_pickled = cloudpickle .dumps (zarr_chains , protocol = - 1 )
417
462
418
463
self ._samplers = [
419
464
ProcessAdapter (
@@ -426,6 +471,8 @@ def __init__(
426
471
start ,
427
472
blas_cores ,
428
473
mp_ctx ,
474
+ zarr_chains = zarr_chains ,
475
+ zarr_chains_pickled = zarr_chains_pickled ,
429
476
)
430
477
for chain , rng , start in zip (range (chains ), rngs , start_points )
431
478
]
0 commit comments