50
50
find_observations ,
51
51
)
52
52
from pymc .backends .base import IBaseTrace , MultiTrace , _choose_chains
53
+ from pymc .backends .zarr import ZarrTrace
53
54
from pymc .blocking import DictToArrayBijection
54
55
from pymc .exceptions import SamplingError
55
56
from pymc .initial_point import PointType , StartDict , make_initial_point_fns_per_chain
@@ -503,7 +504,7 @@ def sample(
503
504
model : Model | None = None ,
504
505
compile_kwargs : dict | None = None ,
505
506
** kwargs ,
506
- ) -> InferenceData | MultiTrace :
507
+ ) -> InferenceData | MultiTrace | ZarrTrace :
507
508
r"""Draw samples from the posterior using the given step methods.
508
509
509
510
Multiple step methods are supported via compound step methods.
@@ -570,7 +571,13 @@ def sample(
570
571
Number of iterations of initializer. Only works for 'ADVI' init methods.
571
572
trace : backend, optional
572
573
A backend instance or None.
573
- If None, the NDArray backend is used.
574
+ If ``None``, a ``MultiTrace`` object with underlying ``NDArray`` trace objects
575
+ is used. If ``trace`` is a :class:`~pymc.backends.zarr.ZarrTrace` instance,
576
+ the drawn samples will be written onto the desired storage while sampling is
577
+ on-going. This means sampling runs that, for whatever reason, die in the middle
578
+ of their execution will write the partial results onto the storage. If the
579
+ storage persists on disk, these results should be available even after a server
580
+ crash. See :class:`~pymc.backends.zarr.ZarrTrace` for more information.
574
581
discard_tuned_samples : bool
575
582
Whether to discard posterior samples of the tune interval.
576
583
compute_convergence_checks : bool, default=True
@@ -607,8 +614,12 @@ def sample(
607
614
608
615
Returns
609
616
-------
610
- trace : pymc.backends.base.MultiTrace or arviz.InferenceData
611
- A ``MultiTrace`` or ArviZ ``InferenceData`` object that contains the samples.
617
+ trace : pymc.backends.base.MultiTrace | pymc.backends.zarr.ZarrTrace | arviz.InferenceData
618
+ A ``MultiTrace``, :class:`~arviz.InferenceData` or
619
+ :class:`~pymc.backends.zarr.ZarrTrace` object that contains the samples. A
620
+ ``ZarrTrace`` is only returned if the supplied ``trace`` argument is a
621
+ ``ZarrTrace`` instance. Refer to :class:`~pymc.backends.zarr.ZarrTrace` for
622
+ the benefits this backend provides.
612
623
613
624
Notes
614
625
-----
@@ -741,7 +752,7 @@ def joined_blas_limiter():
741
752
rngs = get_random_generator (random_seed ).spawn (chains )
742
753
random_seed_list = [rng .integers (2 ** 30 ) for rng in rngs ]
743
754
744
- if not discard_tuned_samples and not return_inferencedata :
755
+ if not discard_tuned_samples and not return_inferencedata and not isinstance ( trace , ZarrTrace ) :
745
756
warnings .warn (
746
757
"Tuning samples will be included in the returned `MultiTrace` object, which can lead to"
747
758
" complications in your downstream analysis. Please consider to switch to `InferenceData`:\n "
@@ -852,6 +863,7 @@ def joined_blas_limiter():
852
863
trace_vars = trace_vars ,
853
864
initial_point = initial_points [0 ],
854
865
model = model ,
866
+ tune = tune ,
855
867
)
856
868
857
869
sample_args = {
@@ -934,7 +946,7 @@ def joined_blas_limiter():
934
946
# into a function to make it easier to test and refactor.
935
947
return _sample_return (
936
948
run = run ,
937
- traces = traces ,
949
+ traces = trace if isinstance ( trace , ZarrTrace ) else traces ,
938
950
tune = tune ,
939
951
t_sampling = t_sampling ,
940
952
discard_tuned_samples = discard_tuned_samples ,
@@ -949,7 +961,7 @@ def joined_blas_limiter():
949
961
def _sample_return (
950
962
* ,
951
963
run : RunType | None ,
952
- traces : Sequence [IBaseTrace ],
964
+ traces : Sequence [IBaseTrace ] | ZarrTrace ,
953
965
tune : int ,
954
966
t_sampling : float ,
955
967
discard_tuned_samples : bool ,
@@ -958,18 +970,69 @@ def _sample_return(
958
970
keep_warning_stat : bool ,
959
971
idata_kwargs : dict [str , Any ],
960
972
model : Model ,
961
- ) -> InferenceData | MultiTrace :
973
+ ) -> InferenceData | MultiTrace | ZarrTrace :
962
974
"""Pick/slice chains, run diagnostics and convert to the desired return type.
963
975
964
976
Final step of `pm.sample`.
965
977
"""
978
+ if isinstance (traces , ZarrTrace ):
979
+ # Split warmup from posterior samples
980
+ traces .split_warmup_groups ()
981
+
982
+ # Set sampling time
983
+ traces .sampling_time = t_sampling
984
+
985
+ # Compute number of actual draws per chain
986
+ total_draws_per_chain = traces ._sampling_state .draw_idx [:]
987
+ n_chains = len (traces .straces )
988
+ desired_tune = traces .tuning_steps
989
+ desired_draw = len (traces .posterior .draw )
990
+ tuning_steps_per_chain = np .clip (total_draws_per_chain , 0 , desired_tune )
991
+ draws_per_chain = total_draws_per_chain - tuning_steps_per_chain
992
+
993
+ total_n_tune = tuning_steps_per_chain .sum ()
994
+ total_draws = draws_per_chain .sum ()
995
+
996
+ _log .info (
997
+ f'Sampling { n_chains } chain{ "s" if n_chains > 1 else "" } for { desired_tune :_d} desired tune and { desired_draw :_d} desired draw iterations '
998
+ f"(Actually sampled { total_n_tune :_d} tune and { total_draws :_d} draws total) "
999
+ f"took { t_sampling :.0f} seconds."
1000
+ )
1001
+
1002
+ if compute_convergence_checks or return_inferencedata :
1003
+ idata = traces .to_inferencedata (save_warmup = not discard_tuned_samples )
1004
+ log_likelihood = idata_kwargs .pop ("log_likelihood" , False )
1005
+ if log_likelihood :
1006
+ from pymc .stats .log_density import compute_log_likelihood
1007
+
1008
+ idata = compute_log_likelihood (
1009
+ idata ,
1010
+ var_names = None if log_likelihood is True else log_likelihood ,
1011
+ extend_inferencedata = True ,
1012
+ model = model ,
1013
+ sample_dims = ["chain" , "draw" ],
1014
+ progressbar = False ,
1015
+ )
1016
+ if compute_convergence_checks :
1017
+ warns = run_convergence_checks (idata , model )
1018
+ for warn in warns :
1019
+ traces ._sampling_state .global_warnings .append (np .array ([warn ]))
1020
+ log_warnings (warns )
1021
+
1022
+ if return_inferencedata :
1023
+ # By default we drop the "warning" stat which contains `SamplerWarning`
1024
+ # objects that can not be stored with `.to_netcdf()`.
1025
+ if not keep_warning_stat :
1026
+ return drop_warning_stat (idata )
1027
+ return idata
1028
+ return traces
1029
+
966
1030
# Pick and slice chains to keep the maximum number of samples
967
1031
if discard_tuned_samples :
968
1032
traces , length = _choose_chains (traces , tune )
969
1033
else :
970
1034
traces , length = _choose_chains (traces , 0 )
971
1035
mtrace = MultiTrace (traces )[:length ]
972
-
973
1036
# count the number of tune/draw iterations that happened
974
1037
# ideally via the "tune" statistic, but not all samplers record it!
975
1038
if "tune" in mtrace .stat_names :
0 commit comments