
Commit 147b92e

Integrate ZarrTrace into pymc.sample
1 parent a773405

File tree: 3 files changed (+197 -10 lines)

  pymc/backends/__init__.py
  pymc/sampling/mcmc.py
  tests/backends/test_zarr.py

pymc/backends/__init__.py (+14 -1)
@@ -72,6 +72,7 @@
 from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
 from pymc.backends.base import BaseTrace, IBaseTrace
 from pymc.backends.ndarray import NDArray
+from pymc.backends.zarr import ZarrTrace
 from pymc.blocking import PointType
 from pymc.model import Model
 from pymc.step_methods.compound import BlockedStep, CompoundStep
@@ -120,15 +121,27 @@ def _init_trace(

 def init_traces(
     *,
-    backend: TraceOrBackend | None,
+    backend: TraceOrBackend | ZarrTrace | None,
     chains: int,
     expected_length: int,
     step: BlockedStep | CompoundStep,
     initial_point: PointType,
     model: Model,
     trace_vars: list[TensorVariable] | None = None,
+    tune: int = 0,
 ) -> tuple[RunType | None, Sequence[IBaseTrace]]:
     """Initialize a trace recorder for each chain."""
+    if isinstance(backend, ZarrTrace):
+        backend.init_trace(
+            chains=chains,
+            draws=expected_length - tune,
+            tune=tune,
+            step=step,
+            model=model,
+            vars=trace_vars,
+            test_point=initial_point,
+        )
+        return None, backend.straces
     if HAS_MCB and isinstance(backend, Backend):
         return init_chain_adapters(
             backend=backend,
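A note on the new branch above: when the backend is a ZarrTrace, init_traces lets the backend initialize its own storage and hands back its per-chain straces with no Run object. Below is a minimal sketch of that code path, calling init_traces directly with a made-up model, step method, and draw counts; pm.sample normally does this internally.

# Sketch only: pm.sample builds the step method and initial point itself before
# calling init_traces. All concrete values below are made up for illustration.
import pymc as pm
import zarr

from pymc.backends import init_traces
from pymc.backends.zarr import ZarrTrace

with pm.Model() as model:
    pm.Normal("x")
    step = pm.NUTS()
    initial_point = model.initial_point()

    run, straces = init_traces(
        backend=ZarrTrace(store=zarr.MemoryStore()),
        chains=2,
        expected_length=100 + 50,  # draws + tune
        step=step,
        initial_point=initial_point,
        model=model,
        tune=50,
    )

# With a ZarrTrace backend there is no Run object, only the per-chain straces.
assert run is None and len(straces) == 2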

pymc/sampling/mcmc.py (+72 -9)
@@ -50,6 +50,7 @@
     find_observations,
 )
 from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
+from pymc.backends.zarr import ZarrTrace
 from pymc.blocking import DictToArrayBijection
 from pymc.exceptions import SamplingError
 from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
@@ -503,7 +504,7 @@ def sample(
     model: Model | None = None,
     compile_kwargs: dict | None = None,
     **kwargs,
-) -> InferenceData | MultiTrace:
+) -> InferenceData | MultiTrace | ZarrTrace:
     r"""Draw samples from the posterior using the given step methods.

     Multiple step methods are supported via compound step methods.
@@ -570,7 +571,13 @@ def sample(
         Number of iterations of initializer. Only works for 'ADVI' init methods.
     trace : backend, optional
         A backend instance or None.
-        If None, the NDArray backend is used.
+        If ``None``, a ``MultiTrace`` object backed by ``NDArray`` trace objects is
+        used. If ``trace`` is a :class:`~pymc.backends.zarr.ZarrTrace` instance, the
+        drawn samples are written to the chosen storage while sampling is ongoing.
+        This means that sampling runs which, for whatever reason, die mid-execution
+        will still have written their partial results to the storage. If the storage
+        persists on disk, these results remain available even after a server crash.
+        See :class:`~pymc.backends.zarr.ZarrTrace` for more information.
     discard_tuned_samples : bool
         Whether to discard posterior samples of the tune interval.
     compute_convergence_checks : bool, default=True
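A rough usage sketch of the behavior documented above (not part of the commit; the toy model is made up). A zarr.MemoryStore keeps the example self-contained; a persistent zarr store, e.g. one backed by a directory on disk, is what gives the crash-resilience described in the docstring.

# Rough usage sketch only; the toy model and draw counts are made up.
import pymc as pm
import zarr

from pymc.backends.zarr import ZarrTrace

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[-0.3, 0.1, 0.4])

    # Draws are written to the store as sampling progresses, so a run that
    # dies partway through still leaves its partial results in the store.
    trace = ZarrTrace(store=zarr.MemoryStore())
    idata = pm.sample(draws=500, tune=500, trace=trace)  # InferenceData by default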
@@ -607,8 +614,12 @@ def sample(

     Returns
     -------
-    trace : pymc.backends.base.MultiTrace or arviz.InferenceData
-        A ``MultiTrace`` or ArviZ ``InferenceData`` object that contains the samples.
+    trace : pymc.backends.base.MultiTrace | pymc.backends.zarr.ZarrTrace | arviz.InferenceData
+        A ``MultiTrace``, :class:`~arviz.InferenceData` or
+        :class:`~pymc.backends.zarr.ZarrTrace` object that contains the samples. A
+        ``ZarrTrace`` is only returned if the supplied ``trace`` argument is a
+        ``ZarrTrace`` instance. Refer to :class:`~pymc.backends.zarr.ZarrTrace` for
+        the benefits this backend provides.

     Notes
     -----
@@ -741,7 +752,7 @@ def joined_blas_limiter():
     rngs = get_random_generator(random_seed).spawn(chains)
     random_seed_list = [rng.integers(2**30) for rng in rngs]

-    if not discard_tuned_samples and not return_inferencedata:
+    if not discard_tuned_samples and not return_inferencedata and not isinstance(trace, ZarrTrace):
         warnings.warn(
             "Tuning samples will be included in the returned `MultiTrace` object, which can lead to"
             " complications in your downstream analysis. Please consider to switch to `InferenceData`:\n"
@@ -852,6 +863,7 @@ def joined_blas_limiter():
         trace_vars=trace_vars,
         initial_point=initial_points[0],
         model=model,
+        tune=tune,
     )

     sample_args = {
@@ -934,7 +946,7 @@ def joined_blas_limiter():
     # into a function to make it easier to test and refactor.
     return _sample_return(
         run=run,
-        traces=traces,
+        traces=trace if isinstance(trace, ZarrTrace) else traces,
         tune=tune,
         t_sampling=t_sampling,
         discard_tuned_samples=discard_tuned_samples,
@@ -949,7 +961,7 @@
 def _sample_return(
     *,
     run: RunType | None,
-    traces: Sequence[IBaseTrace],
+    traces: Sequence[IBaseTrace] | ZarrTrace,
     tune: int,
     t_sampling: float,
     discard_tuned_samples: bool,
@@ -958,18 +970,69 @@ def _sample_return(
     keep_warning_stat: bool,
     idata_kwargs: dict[str, Any],
     model: Model,
-) -> InferenceData | MultiTrace:
+) -> InferenceData | MultiTrace | ZarrTrace:
     """Pick/slice chains, run diagnostics and convert to the desired return type.

     Final step of `pm.sampler`.
     """
+    if isinstance(traces, ZarrTrace):
+        # Split warmup from posterior samples
+        traces.split_warmup_groups()
+
+        # Set sampling time
+        traces.sampling_time = t_sampling
+
+        # Compute number of actual draws per chain
+        total_draws_per_chain = traces._sampling_state.draw_idx[:]
+        n_chains = len(traces.straces)
+        desired_tune = traces.tuning_steps
+        desired_draw = len(traces.posterior.draw)
+        tuning_steps_per_chain = np.clip(total_draws_per_chain, 0, desired_tune)
+        draws_per_chain = total_draws_per_chain - tuning_steps_per_chain
+
+        total_n_tune = tuning_steps_per_chain.sum()
+        total_draws = draws_per_chain.sum()
+
+        _log.info(
+            f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations '
+            f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) "
+            f"took {t_sampling:.0f} seconds."
+        )
+
+        if compute_convergence_checks or return_inferencedata:
+            idata = traces.to_inferencedata(save_warmup=not discard_tuned_samples)
+            log_likelihood = idata_kwargs.pop("log_likelihood", False)
+            if log_likelihood:
+                from pymc.stats.log_density import compute_log_likelihood
+
+                idata = compute_log_likelihood(
+                    idata,
+                    var_names=None if log_likelihood is True else log_likelihood,
+                    extend_inferencedata=True,
+                    model=model,
+                    sample_dims=["chain", "draw"],
+                    progressbar=False,
+                )
+            if compute_convergence_checks:
+                warns = run_convergence_checks(idata, model)
+                for warn in warns:
+                    traces._sampling_state.global_warnings.append(np.array([warn]))
+                log_warnings(warns)
+
+        if return_inferencedata:
+            # By default we drop the "warning" stat which contains `SamplerWarning`
+            # objects that can not be stored with `.to_netcdf()`.
+            if not keep_warning_stat:
+                return drop_warning_stat(idata)
+            return idata
+        return traces
+
     # Pick and slice chains to keep the maximum number of samples
     if discard_tuned_samples:
         traces, length = _choose_chains(traces, tune)
     else:
         traces, length = _choose_chains(traces, 0)
     mtrace = MultiTrace(traces)[:length]
-
     # count the number of tune/draw iterations that happened
     # ideally via the "tune" statistic, but not all samplers record it!
     if "tune" in mtrace.stat_names:

tests/backends/test_zarr.py (+111)
@@ -19,6 +19,8 @@
 import pytest
 import zarr

+from arviz import InferenceData
+
 import pymc as pm

 from pymc.backends.zarr import ZarrTrace
@@ -357,3 +359,112 @@ def test_split_warmup(tune, model, model_step, include_transformed):
         if len(dims) >= 2 and dims[1] == "draw":
             assert sample_stats_array.shape[1] == draws
             assert trace.root["warmup_sample_stats"][var_name].shape[1] == tune
+
+
+@pytest.fixture(scope="function", params=[True, False])
+def discard_tuned_samples(request):
+    return request.param
+
+
+@pytest.fixture(scope="function", params=[True, False])
+def return_inferencedata(request):
+    return request.param
+
+
+@pytest.fixture(scope="function", params=[True, False])
+def keep_warning_stat(request):
+    return request.param
+
+
+@pytest.fixture(scope="function", params=[True, False])
+def parallel(request):
+    return request.param
+
+
+@pytest.fixture(scope="function", params=[True, False])
+def log_likelihood(request):
+    return request.param
+
+
+def test_sample(
+    model,
+    model_step,
+    include_transformed,
+    discard_tuned_samples,
+    return_inferencedata,
+    keep_warning_stat,
+    parallel,
+    log_likelihood,
+    draws_per_chunk,
+):
+    if not return_inferencedata and not log_likelihood:
+        pytest.skip(
+            reason="log_likelihood is only computed if an inference data object is returned"
+        )
+    store = zarr.MemoryStore()
+    trace = ZarrTrace(
+        store=store, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk
+    )
+    tune = 2
+    draws = 3
+    if parallel:
+        chains = 2
+        cores = 2
+    else:
+        chains = 1
+        cores = 1
+    with model:
+        out_trace = pm.sample(
+            draws=draws,
+            tune=tune,
+            chains=chains,
+            cores=cores,
+            trace=trace,
+            step=model_step,
+            discard_tuned_samples=discard_tuned_samples,
+            return_inferencedata=return_inferencedata,
+            keep_warning_stat=keep_warning_stat,
+            idata_kwargs={"log_likelihood": log_likelihood},
+        )
+
+    if not return_inferencedata:
+        assert isinstance(out_trace, ZarrTrace)
+        assert out_trace.root.store is trace.root.store
+    else:
+        assert isinstance(out_trace, InferenceData)
+
+    expected_groups = {"posterior", "constant_data", "observed_data", "sample_stats"}
+    if include_transformed:
+        expected_groups |= {"unconstrained_posterior"}
+    if not return_inferencedata or not discard_tuned_samples:
+        expected_groups |= {"warmup_posterior", "warmup_sample_stats"}
+        if include_transformed:
+            expected_groups |= {"warmup_unconstrained_posterior"}
+    if not return_inferencedata:
+        expected_groups |= {"_sampling_state"}
+    elif log_likelihood:
+        expected_groups |= {"log_likelihood"}
+    assert set(out_trace.groups()) == expected_groups
+
+    if return_inferencedata:
+        warning_stat = (
+            "sampler_1__warning" if isinstance(model_step, CompoundStep) else "sampler_0__warning"
+        )
+        if keep_warning_stat:
+            assert warning_stat in out_trace.sample_stats
+        else:
+            assert warning_stat not in out_trace.sample_stats
+
+    # Assert that all variables have non empty samples (not NaNs)
+    if return_inferencedata:
+        assert all(
+            (not np.any(np.isnan(v))) and v.shape[:2] == (chains, draws)
+            for v in out_trace.posterior.data_vars.values()
+        )
+    else:
+        dimensions = {*model.coords, "a_dim_0", "a_dim_1", "chain", "draw"}
+        assert all(
+            (not np.any(np.isnan(v[:]))) and v.shape[:2] == (chains, draws)
+            for name, v in out_trace.posterior.arrays()
+            if name not in dimensions
+        )
