Skip to content

Commit 35cdfa6

Browse files
committed
Write sampling state periodically
1 parent 147b92e commit 35cdfa6

File tree

4 files changed

+186
-16
lines changed

4 files changed

+186
-16
lines changed

pymc/sampling/mcmc.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Any,
2727
Literal,
2828
TypeAlias,
29+
cast,
2930
overload,
3031
)
3132

@@ -40,6 +41,7 @@
4041
from rich.theme import Theme
4142
from threadpoolctl import threadpool_limits
4243
from typing_extensions import Protocol
44+
from zarr.storage import MemoryStore
4345

4446
import pymc as pm
4547

@@ -50,7 +52,7 @@
5052
find_observations,
5153
)
5254
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
53-
from pymc.backends.zarr import ZarrTrace
55+
from pymc.backends.zarr import ZarrChain, ZarrTrace
5456
from pymc.blocking import DictToArrayBijection
5557
from pymc.exceptions import SamplingError
5658
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
@@ -1275,6 +1277,8 @@ def _iter_sample(
12751277
step.set_rng(rng)
12761278

12771279
point = start
1280+
if isinstance(trace, ZarrChain):
1281+
trace.link_stepper(step)
12781282

12791283
try:
12801284
step.tune = bool(tune)
@@ -1297,12 +1301,18 @@ def _iter_sample(
12971301

12981302
yield diverging
12991303
except KeyboardInterrupt:
1304+
if isinstance(trace, ZarrChain):
1305+
trace.record_sampling_state(step=step)
13001306
trace.close()
13011307
raise
13021308
except BaseException:
1309+
if isinstance(trace, ZarrChain):
1310+
trace.record_sampling_state(step=step)
13031311
trace.close()
13041312
raise
13051313
else:
1314+
if isinstance(trace, ZarrChain):
1315+
trace.record_sampling_state(step=step)
13061316
trace.close()
13071317

13081318

@@ -1361,6 +1371,19 @@ def _mp_sample(
13611371

13621372
# We did draws += tune in pm.sample
13631373
draws -= tune
1374+
zarr_chains: list[ZarrChain] | None = None
1375+
zarr_recording = False
1376+
if all(isinstance(trace, ZarrChain) for trace in traces):
1377+
if isinstance(cast(ZarrChain, traces[0])._posterior.store, MemoryStore):
1378+
warnings.warn(
1379+
"Parallel sampling with MemoryStore zarr store wont write the processes "
1380+
"step method sampling state. If you wish to be able to access the step "
1381+
"method sampling state, please use a different storage backend, e.g. "
1382+
"DirectoryStore or ZipStore"
1383+
)
1384+
else:
1385+
zarr_chains = cast(list[ZarrChain], traces)
1386+
zarr_recording = True
13641387

13651388
sampler = ps.ParallelSampler(
13661389
draws=draws,
@@ -1374,13 +1397,16 @@ def _mp_sample(
13741397
progressbar_theme=progressbar_theme,
13751398
blas_cores=blas_cores,
13761399
mp_ctx=mp_ctx,
1400+
zarr_chains=zarr_chains,
13771401
)
13781402
try:
13791403
try:
13801404
with sampler:
13811405
for draw in sampler:
13821406
strace = traces[draw.chain]
1383-
strace.record(draw.point, draw.stats)
1407+
if not zarr_recording:
1408+
# Zarr recording happens in each process
1409+
strace.record(draw.point, draw.stats)
13841410
log_warning_stats(draw.stats)
13851411

13861412
if callback is not None:

pymc/sampling/parallel.py

+47
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from collections import namedtuple
2424
from collections.abc import Sequence
25+
from typing import cast
2526

2627
import cloudpickle
2728
import numpy as np
@@ -31,6 +32,7 @@
3132
from rich.theme import Theme
3233
from threadpoolctl import threadpool_limits
3334

35+
from pymc.backends.zarr import ZarrChain
3436
from pymc.blocking import DictToArrayBijection
3537
from pymc.exceptions import SamplingError
3638
from pymc.util import (
@@ -104,13 +106,25 @@ def __init__(
104106
tune: int,
105107
rng_state: RandomGeneratorState,
106108
blas_cores,
109+
chain: int,
110+
zarr_chains: list[ZarrChain] | bytes | None = None,
111+
zarr_chains_is_pickled: bool = False,
107112
):
108113
# For some strange reason, spawn multiprocessing doesn't copy the rng
109114
# seed sequence, so we have to rebuild it from scratch
110115
rng = random_generator_from_state(rng_state)
111116
self._msg_pipe = msg_pipe
112117
self._step_method = step_method
113118
self._step_method_is_pickled = step_method_is_pickled
119+
self.chain = chain
120+
self._zarr_recording = False
121+
self._zarr_chain: ZarrChain | None = None
122+
if zarr_chains_is_pickled:
123+
self._zarr_chain = cloudpickle.loads(zarr_chains)[self.chain]
124+
elif zarr_chains is not None:
125+
self._zarr_chain = cast(list[ZarrChain], zarr_chains)[self.chain]
126+
self._zarr_recording = self._zarr_chain is not None
127+
114128
self._shared_point = shared_point
115129
self._rng = rng
116130
self._draws = draws
@@ -135,6 +149,7 @@ def run(self):
135149
# We do not create this in __init__, as pickling this
136150
# would destroy the shared memory.
137151
self._unpickle_step_method()
152+
self._link_step_to_zarrchain()
138153
self._point = self._make_numpy_refs()
139154
self._start_loop()
140155
except KeyboardInterrupt:
@@ -148,6 +163,10 @@ def run(self):
148163
finally:
149164
self._msg_pipe.close()
150165

166+
def _link_step_to_zarrchain(self):
167+
if self._zarr_recording:
168+
self._zarr_chain.link_stepper(self._step_method)
169+
151170
def _wait_for_abortion(self):
152171
while True:
153172
msg = self._recv_msg()
@@ -170,6 +189,7 @@ def _recv_msg(self):
170189
return self._msg_pipe.recv()
171190

172191
def _start_loop(self):
192+
zarr_recording = self._zarr_recording
173193
self._step_method.set_rng(self._rng)
174194

175195
draw = 0
@@ -199,6 +219,8 @@ def _start_loop(self):
199219
if msg[0] == "abort":
200220
raise KeyboardInterrupt()
201221
elif msg[0] == "write_next":
222+
if zarr_recording:
223+
self._zarr_chain.record(point, stats)
202224
self._write_point(point)
203225
is_last = draw + 1 == self._draws + self._tune
204226
self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats))
@@ -225,6 +247,8 @@ def __init__(
225247
start: dict[str, np.ndarray],
226248
blas_cores,
227249
mp_ctx,
250+
zarr_chains: list[ZarrChain] | None = None,
251+
zarr_chains_pickled: bytes | None = None,
228252
):
229253
self.chain = chain
230254
process_name = f"worker_chain_{chain}"
@@ -247,6 +271,16 @@ def __init__(
247271
self._readable = True
248272
self._num_samples = 0
249273

274+
zarr_chains_send: list[ZarrChain] | bytes | None = None
275+
if zarr_chains_pickled is not None:
276+
zarr_chains_send = zarr_chains_pickled
277+
elif zarr_chains is not None:
278+
if mp_ctx.get_start_method() == "spawn":
279+
raise ValueError(
280+
"please provide a pre-pickled zarr_chains when multiprocessing start method is 'spawn'"
281+
)
282+
zarr_chains_send = zarr_chains
283+
250284
if step_method_pickled is not None:
251285
step_method_send = step_method_pickled
252286
else:
@@ -270,6 +304,9 @@ def __init__(
270304
tune,
271305
get_state_from_generator(rng),
272306
blas_cores,
307+
self.chain,
308+
zarr_chains_send,
309+
zarr_chains_pickled is not None,
273310
),
274311
)
275312
self._process.start()
@@ -392,6 +429,7 @@ def __init__(
392429
progressbar_theme: Theme | None = default_progress_theme,
393430
blas_cores: int | None = None,
394431
mp_ctx=None,
432+
zarr_chains: list[ZarrChain] | None = None,
395433
):
396434
if any(len(arg) != chains for arg in [rngs, start_points]):
397435
raise ValueError(f"Number of rngs and start_points must be {chains}.")
@@ -412,8 +450,15 @@ def __init__(
412450
mp_ctx = multiprocessing.get_context(mp_ctx)
413451

414452
step_method_pickled = None
453+
zarr_chains_pickled = None
454+
self.zarr_recording = False
455+
if zarr_chains is not None:
456+
assert all(isinstance(zarr_chain, ZarrChain) for zarr_chain in zarr_chains)
457+
self.zarr_recording = True
415458
if mp_ctx.get_start_method() != "fork":
416459
step_method_pickled = cloudpickle.dumps(step_method, protocol=-1)
460+
if zarr_chains is not None:
461+
zarr_chains_pickled = cloudpickle.dumps(zarr_chains, protocol=-1)
417462

418463
self._samplers = [
419464
ProcessAdapter(
@@ -426,6 +471,8 @@ def __init__(
426471
start,
427472
blas_cores,
428473
mp_ctx,
474+
zarr_chains=zarr_chains,
475+
zarr_chains_pickled=zarr_chains_pickled,
429476
)
430477
for chain, rng, start in zip(range(chains), rngs, start_points)
431478
]

pymc/sampling/population.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
2828

2929
from pymc.backends.base import BaseTrace
30+
from pymc.backends.zarr import ZarrChain
3031
from pymc.initial_point import PointType
3132
from pymc.model import Model, modelcontext
3233
from pymc.stats.convergence import log_warning_stats
@@ -36,6 +37,7 @@
3637
PopulationArrayStepShared,
3738
StatsType,
3839
)
40+
from pymc.step_methods.compound import StepMethodState
3941
from pymc.step_methods.metropolis import DEMetropolis
4042
from pymc.util import CustomProgress
4143

@@ -81,6 +83,11 @@ def _sample_population(
8183
Show progress bars? (defaults to True)
8284
parallelize : bool
8385
Setting for multiprocess parallelization
86+
traces : Sequence[BaseTrace]
87+
A sequence of chain traces where the sampling results will be stored. Can be
88+
a sequence of :py:class:`~pymc.backends.ndarray.NDArray`,
89+
:py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or
90+
:py:class:`~pymc.backends.zarr.ZarrChain`.
8491
"""
8592
warn_population_size(
8693
step=step,
@@ -263,6 +270,9 @@ def _run_secondary(c, stepper_dumps, secondary_end, task, progress):
263270
# receiving a None is the signal to exit
264271
if incoming is None:
265272
break
273+
elif incoming == "sampling_state":
274+
secondary_end.send((c, stepper.sampling_state))
275+
continue
266276
tune_stop, population = incoming
267277
if tune_stop:
268278
stepper.stop_tuning()
@@ -307,6 +317,14 @@ def step(self, tune_stop: bool, population) -> list[tuple[PointType, StatsType]]
307317
updates.append(self._steppers[c].step(population[c]))
308318
return updates
309319

320+
def request_sampling_state(self, chain) -> StepMethodState:
321+
if self.is_parallelized:
322+
self._primary_ends[chain].send(("sampling_state",))
323+
_, sampling_state = self._primary_ends[chain].recv()
324+
else:
325+
sampling_state = self._steppers[chain].sampling_state
326+
return sampling_state
327+
310328

311329
def _prepare_iter_population(
312330
*,
@@ -332,6 +350,11 @@ def _prepare_iter_population(
332350
Start points for each chain
333351
parallelize : bool
334352
Setting for multiprocess parallelization
353+
traces : Sequence[BaseTrace]
354+
A sequence of chain traces where the sampling results will be stored. Can be
355+
a sequence of :py:class:`~pymc.backends.ndarray.NDArray`,
356+
:py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or
357+
:py:class:`~pymc.backends.zarr.ZarrChain`.
335358
tune : int
336359
Number of iterations to tune.
337360
rngs: sequence of random Generators
@@ -411,8 +434,11 @@ def _iter_population(
411434
the helper object for (parallelized) stepping of chains
412435
steppers : list
413436
The step methods for each chain
414-
traces : list
415-
Traces for each chain
437+
traces : Sequence[BaseTrace]
438+
A sequence of chain traces where the sampling results will be stored. Can be
439+
a sequence of :py:class:`~pymc.backends.ndarray.NDArray`,
440+
:py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or
441+
:py:class:`~pymc.backends.zarr.ZarrChain`.
416442
points : list
417443
population of chain states
418444
@@ -432,8 +458,11 @@ def _iter_population(
432458
# apply the update to the points and record to the traces
433459
for c, strace in enumerate(traces):
434460
points[c], stats = updates[c]
435-
strace.record(points[c], stats)
461+
flushed = strace.record(points[c], stats)
436462
log_warning_stats(stats)
463+
if flushed and isinstance(strace, ZarrChain):
464+
sampling_state = popstep.request_sampling_state(c)
465+
strace.store_sampling_state(sampling_state)
437466
# yield the state of all chains in parallel
438467
yield i
439468
except KeyboardInterrupt:

0 commit comments

Comments
 (0)