@@ -252,9 +252,9 @@ def sample(
252252 draws = 1000 ,
253253 step = None ,
254254 init = "auto" ,
255- n_init = 200000 ,
255+ n_init = 200_000 ,
256256 start = None ,
257- trace = None ,
257+ trace : Optional [ Union [ BaseTrace , List [ str ]]] = None ,
258258 chain_idx = 0 ,
259259 chains = None ,
260260 cores = None ,
@@ -296,10 +296,9 @@ def sample(
296296 Defaults to ``trace.point(-1))`` if there is a trace provided and model.initial_point if not
297297 (defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can
298298 overwrite the default.
299- trace : backend, list, or MultiTrace
300- This should be a backend instance, a list of variables to track, or a MultiTrace object
301- with past values. If a MultiTrace object is given, it must contain samples for the chain
302- number ``chain``. If None or a list of variables, the NDArray backend is used.
299+ trace : backend or list
300+ This should be a backend instance, or a list of variables to track.
301+ If None or a list of variables, the NDArray backend is used.
303302 chain_idx : int
304303 Chain number used to store sample in backend. If ``chains`` is greater than one, chain
305304 numbers will start here.
@@ -813,7 +812,7 @@ def _sample(
813812 start ,
814813 draws : int ,
815814 step = None ,
816- trace = None ,
815+ trace : Optional [ Union [ BaseTrace , List [ str ]]] = None ,
817816 tune = None ,
818817 model : Optional [Model ] = None ,
819818 callback = None ,
@@ -839,10 +838,9 @@ def _sample(
839838 The number of samples to draw
840839 step : function
841840 Step function
842- trace : backend, list, or MultiTrace
843- This should be a backend instance, a list of variables to track, or a MultiTrace object
844- with past values. If a MultiTrace object is given, it must contain samples for the chain
845- number ``chain``. If None or a list of variables, the NDArray backend is used.
841+ trace : backend or list
842+ This should be a backend instance, or a list of variables to track.
843+ If None or a list of variables, the NDArray backend is used.
846844 tune : int, optional
847845 Number of iterations to tune, if applicable (defaults to None)
848846 model : Model (optional if in ``with`` context)
@@ -899,10 +897,9 @@ def iter_sample(
899897 start : dict
900898 Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if
901899 there is a trace provided and model.initial_point if not (defaults to empty dict)
902- trace : backend, list, or MultiTrace
903- This should be a backend instance, a list of variables to track, or a MultiTrace object
904- with past values. If a MultiTrace object is given, it must contain samples for the chain
905- number ``chain``. If None or a list of variables, the NDArray backend is used.
900+ trace : backend or list
901+ This should be a backend instance, or a list of variables to track.
902+ If None or a list of variables, the NDArray backend is used.
906903 chain : int, optional
907904 Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
908905 will start here.
@@ -939,7 +936,7 @@ def _iter_sample(
939936 draws ,
940937 step ,
941938 start = None ,
942- trace = None ,
939+ trace : Optional [ Union [ BaseTrace , List [ str ]]] = None ,
943940 chain = 0 ,
944941 tune = None ,
945942 model = None ,
@@ -955,12 +952,10 @@ def _iter_sample(
955952 step : function
956953 Step function
957954 start : dict, optional
958- Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if
959- there is a trace provided and model.initial_point if not (defaults to empty dict)
960- trace : backend, list, MultiTrace, or None
961- This should be a backend instance, a list of variables to track, or a MultiTrace object
962- with past values. If a MultiTrace object is given, it must contain samples for the chain
963- number ``chain``. If None or a list of variables, the NDArray backend is used.
955+ Starting point in parameter space (or partial point). Defaults to model.initial_point if not (defaults to empty dict)
956+ trace : backend or list
957+ This should be a backend instance, or a list of variables to track.
958+ If None or a list of variables, the NDArray backend is used.
964959 chain : int, optional
965960 Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
966961 will start here.
@@ -986,12 +981,9 @@ def _iter_sample(
986981 if start is None :
987982 start = {}
988983
989- strace = _choose_backend (trace , chain , model = model )
984+ strace = _choose_backend (trace , model = model )
990985
991- if len (strace ) > 0 :
992- model .update_start_vals (start , strace .point (- 1 ))
993- else :
994- model .update_start_vals (start , model .initial_point )
986+ model .update_start_vals (start , model .initial_point )
995987
996988 try :
997989 step = CompoundStep (step )
@@ -1258,7 +1250,7 @@ def _prepare_iter_population(
12581250 # 5. a PopulationStepper is configured for parallelized stepping
12591251
12601252 # 1. prepare a BaseTrace for each chain
1261- traces = [_choose_backend (None , chain , model = model ) for chain in chains ]
1253+ traces = [_choose_backend (None , model = model ) for chain in chains ]
12621254 for c , strace in enumerate (traces ):
12631255 # initialize the trace size and variable transforms
12641256 if len (strace ) > 0 :
@@ -1361,30 +1353,29 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
13611353 steppers [c ].report ._finalize (strace )
13621354
13631355
1364- def _choose_backend (trace , chain , ** kwds ) -> Backend :
1356+ def _choose_backend (trace : Optional [ Union [ BaseTrace , List [ str ]]] , ** kwds ) -> Backend :
13651357 """Selects or creates a NDArray trace backend for a particular chain.
13661358
13671359 Parameters
13681360 ----------
1369- trace : BaseTrace, list, MultiTrace, or None
1370- This should be a BaseTrace, list of variables to track,
1371- or a MultiTrace object with past values.
1372- If a MultiTrace object is given, it must contain samples for the chain number ``chain``.
1361+ trace : BaseTrace, list, or None
1362+ This should be a BaseTrace, or list of variables to track.
13731363 If None or a list of variables, the NDArray backend is used.
1374- chain : int
1375- Number of the chain of interest.
13761364 **kwds :
13771365 keyword arguments to forward to the backend creation
13781366
13791367 Returns
13801368 -------
13811369 trace : BaseTrace
1382- A trace object for the selected chain
1370+ The incoming, or a brand new trace object.
13831371 """
1372+ if isinstance (trace , BaseTrace ) and len (trace ) > 0 :
1373+ raise ValueError ("Continuation of traces is no longer supported." )
1374+ if isinstance (trace , MultiTrace ):
1375+ raise ValueError ("Starting from existing MultiTrace objects is no longer supported." )
1376+
13841377 if isinstance (trace , BaseTrace ):
13851378 return trace
1386- if isinstance (trace , MultiTrace ):
1387- return trace ._straces [chain ]
13881379 if trace is None :
13891380 return NDArray (** kwds )
13901381
@@ -1401,7 +1392,7 @@ def _mp_sample(
14011392 random_seed : list ,
14021393 start : list ,
14031394 progressbar = True ,
1404- trace = None ,
1395+ trace : Optional [ Union [ BaseTrace , List [ str ]]] = None ,
14051396 model = None ,
14061397 callback = None ,
14071398 discard_tuned_samples = True ,
@@ -1430,10 +1421,9 @@ def _mp_sample(
14301421 Starting points for each chain.
14311422 progressbar : bool
14321423 Whether or not to display a progress bar in the command line.
1433- trace : BaseTrace, list, MultiTrace or None
1434- This should be a backend instance, a list of variables to track, or a MultiTrace object
1435- with past values. If a MultiTrace object is given, it must contain samples for the chain
1436- number ``chain``. If None or a list of variables, the NDArray backend is used.
1424+ trace : BaseTrace, list, or None
1425+ This should be a backend instance, or a list of variables to track
1426+ If None or a list of variables, the NDArray backend is used.
14371427 model : Model (optional if in ``with`` context)
14381428 callback : Callable
14391429 A function which gets called for every sample from the trace of a chain. The function is
@@ -1455,10 +1445,10 @@ def _mp_sample(
14551445 traces = []
14561446 for idx in range (chain , chain + chains ):
14571447 if trace is not None :
1458- strace = _choose_backend (copy (trace ), idx , model = model )
1448+ strace = _choose_backend (copy (trace ), model = model )
14591449 else :
1460- strace = _choose_backend (None , idx , model = model )
1461- # for user supply start value, fill-in missing value if the supplied
1450+ strace = _choose_backend (None , model = model )
1451+ # for user supplied start value, fill-in missing value if the supplied
14621452 # dict does not contain all parameters
14631453 model .update_start_vals (start [idx - chain ], model .initial_point )
14641454 if step .generates_stats and strace .supports_sampler_stats :
0 commit comments