Skip to content

Commit f73e933

Browse files
michaelosthegetwiecki
authored andcommitted
Remove support for continuation of traces
1 parent 53e572c commit f73e933

File tree

2 files changed

+55
-48
lines changed

2 files changed

+55
-48
lines changed

pymc3/sampling.py

+36-46
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,9 @@ def sample(
252252
draws=1000,
253253
step=None,
254254
init="auto",
255-
n_init=200000,
255+
n_init=200_000,
256256
start=None,
257-
trace=None,
257+
trace: Optional[Union[BaseTrace, List[str]]] = None,
258258
chain_idx=0,
259259
chains=None,
260260
cores=None,
@@ -296,10 +296,9 @@ def sample(
296296
Defaults to ``trace.point(-1))`` if there is a trace provided and model.initial_point if not
297297
(defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can
298298
overwrite the default.
299-
trace : backend, list, or MultiTrace
300-
This should be a backend instance, a list of variables to track, or a MultiTrace object
301-
with past values. If a MultiTrace object is given, it must contain samples for the chain
302-
number ``chain``. If None or a list of variables, the NDArray backend is used.
299+
trace : backend or list
300+
This should be a backend instance, or a list of variables to track.
301+
If None or a list of variables, the NDArray backend is used.
303302
chain_idx : int
304303
Chain number used to store sample in backend. If ``chains`` is greater than one, chain
305304
numbers will start here.
@@ -813,7 +812,7 @@ def _sample(
813812
start,
814813
draws: int,
815814
step=None,
816-
trace=None,
815+
trace: Optional[Union[BaseTrace, List[str]]] = None,
817816
tune=None,
818817
model: Optional[Model] = None,
819818
callback=None,
@@ -839,10 +838,9 @@ def _sample(
839838
The number of samples to draw
840839
step : function
841840
Step function
842-
trace : backend, list, or MultiTrace
843-
This should be a backend instance, a list of variables to track, or a MultiTrace object
844-
with past values. If a MultiTrace object is given, it must contain samples for the chain
845-
number ``chain``. If None or a list of variables, the NDArray backend is used.
841+
trace : backend or list
842+
This should be a backend instance, or a list of variables to track.
843+
If None or a list of variables, the NDArray backend is used.
846844
tune : int, optional
847845
Number of iterations to tune, if applicable (defaults to None)
848846
model : Model (optional if in ``with`` context)
@@ -899,10 +897,9 @@ def iter_sample(
899897
start : dict
900898
Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if
901899
there is a trace provided and model.initial_point if not (defaults to empty dict)
902-
trace : backend, list, or MultiTrace
903-
This should be a backend instance, a list of variables to track, or a MultiTrace object
904-
with past values. If a MultiTrace object is given, it must contain samples for the chain
905-
number ``chain``. If None or a list of variables, the NDArray backend is used.
900+
trace : backend or list
901+
This should be a backend instance, or a list of variables to track.
902+
If None or a list of variables, the NDArray backend is used.
906903
chain : int, optional
907904
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
908905
will start here.
@@ -939,7 +936,7 @@ def _iter_sample(
939936
draws,
940937
step,
941938
start=None,
942-
trace=None,
939+
trace: Optional[Union[BaseTrace, List[str]]] = None,
943940
chain=0,
944941
tune=None,
945942
model=None,
@@ -955,12 +952,10 @@ def _iter_sample(
955952
step : function
956953
Step function
957954
start : dict, optional
958-
Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if
959-
there is a trace provided and model.initial_point if not (defaults to empty dict)
960-
trace : backend, list, MultiTrace, or None
961-
This should be a backend instance, a list of variables to track, or a MultiTrace object
962-
with past values. If a MultiTrace object is given, it must contain samples for the chain
963-
number ``chain``. If None or a list of variables, the NDArray backend is used.
955+
Starting point in parameter space (or partial point). Defaults to model.initial_point if not (defaults to empty dict)
956+
trace : backend or list
957+
This should be a backend instance, or a list of variables to track.
958+
If None or a list of variables, the NDArray backend is used.
964959
chain : int, optional
965960
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
966961
will start here.
@@ -986,12 +981,9 @@ def _iter_sample(
986981
if start is None:
987982
start = {}
988983

989-
strace = _choose_backend(trace, chain, model=model)
984+
strace = _choose_backend(trace, model=model)
990985

991-
if len(strace) > 0:
992-
model.update_start_vals(start, strace.point(-1))
993-
else:
994-
model.update_start_vals(start, model.initial_point)
986+
model.update_start_vals(start, model.initial_point)
995987

996988
try:
997989
step = CompoundStep(step)
@@ -1258,7 +1250,7 @@ def _prepare_iter_population(
12581250
# 5. a PopulationStepper is configured for parallelized stepping
12591251

12601252
# 1. prepare a BaseTrace for each chain
1261-
traces = [_choose_backend(None, chain, model=model) for chain in chains]
1253+
traces = [_choose_backend(None, model=model) for chain in chains]
12621254
for c, strace in enumerate(traces):
12631255
# initialize the trace size and variable transforms
12641256
if len(strace) > 0:
@@ -1361,30 +1353,29 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
13611353
steppers[c].report._finalize(strace)
13621354

13631355

1364-
def _choose_backend(trace, chain, **kwds) -> Backend:
1356+
def _choose_backend(trace: Optional[Union[BaseTrace, List[str]]], **kwds) -> Backend:
13651357
"""Selects or creates a NDArray trace backend for a particular chain.
13661358
13671359
Parameters
13681360
----------
1369-
trace : BaseTrace, list, MultiTrace, or None
1370-
This should be a BaseTrace, list of variables to track,
1371-
or a MultiTrace object with past values.
1372-
If a MultiTrace object is given, it must contain samples for the chain number ``chain``.
1361+
trace : BaseTrace, list, or None
1362+
This should be a BaseTrace, or list of variables to track.
13731363
If None or a list of variables, the NDArray backend is used.
1374-
chain : int
1375-
Number of the chain of interest.
13761364
**kwds :
13771365
keyword arguments to forward to the backend creation
13781366
13791367
Returns
13801368
-------
13811369
trace : BaseTrace
1382-
A trace object for the selected chain
1370+
The incoming, or a brand new trace object.
13831371
"""
1372+
if isinstance(trace, BaseTrace) and len(trace) > 0:
1373+
raise ValueError("Continuation of traces is no longer supported.")
1374+
if isinstance(trace, MultiTrace):
1375+
raise ValueError("Starting from existing MultiTrace objects is no longer supported.")
1376+
13841377
if isinstance(trace, BaseTrace):
13851378
return trace
1386-
if isinstance(trace, MultiTrace):
1387-
return trace._straces[chain]
13881379
if trace is None:
13891380
return NDArray(**kwds)
13901381

@@ -1401,7 +1392,7 @@ def _mp_sample(
14011392
random_seed: list,
14021393
start: list,
14031394
progressbar=True,
1404-
trace=None,
1395+
trace: Optional[Union[BaseTrace, List[str]]] = None,
14051396
model=None,
14061397
callback=None,
14071398
discard_tuned_samples=True,
@@ -1430,10 +1421,9 @@ def _mp_sample(
14301421
Starting points for each chain.
14311422
progressbar : bool
14321423
Whether or not to display a progress bar in the command line.
1433-
trace : BaseTrace, list, MultiTrace or None
1434-
This should be a backend instance, a list of variables to track, or a MultiTrace object
1435-
with past values. If a MultiTrace object is given, it must contain samples for the chain
1436-
number ``chain``. If None or a list of variables, the NDArray backend is used.
1424+
trace : BaseTrace, list, or None
1425+
This should be a backend instance, or a list of variables to track
1426+
If None or a list of variables, the NDArray backend is used.
14371427
model : Model (optional if in ``with`` context)
14381428
callback : Callable
14391429
A function which gets called for every sample from the trace of a chain. The function is
@@ -1455,10 +1445,10 @@ def _mp_sample(
14551445
traces = []
14561446
for idx in range(chain, chain + chains):
14571447
if trace is not None:
1458-
strace = _choose_backend(copy(trace), idx, model=model)
1448+
strace = _choose_backend(copy(trace), model=model)
14591449
else:
1460-
strace = _choose_backend(None, idx, model=model)
1461-
# for user supply start value, fill-in missing value if the supplied
1450+
strace = _choose_backend(None, model=model)
1451+
# for user supplied start value, fill-in missing value if the supplied
14621452
# dict does not contain all parameters
14631453
model.update_start_vals(start[idx - chain], model.initial_point)
14641454
if step.generates_stats and strace.supports_sampler_stats:

pymc3/tests/test_sampling.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import pymc3 as pm
3333

3434
from pymc3.aesaraf import compile_rv_inplace
35+
from pymc3.backends.base import MultiTrace
3536
from pymc3.backends.ndarray import NDArray
3637
from pymc3.exceptions import IncorrectArgumentsError, SamplingError
3738
from pymc3.tests.helpers import SeededTest
@@ -438,14 +439,30 @@ def test_constant_named(self):
438439
class TestChooseBackend:
439440
def test_choose_backend_none(self):
440441
with mock.patch("pymc3.sampling.NDArray") as nd:
441-
pm.sampling._choose_backend(None, "chain")
442+
pm.sampling._choose_backend(None)
442443
assert nd.called
443444

444445
def test_choose_backend_list_of_variables(self):
445446
with mock.patch("pymc3.sampling.NDArray") as nd:
446-
pm.sampling._choose_backend(["var1", "var2"], "chain")
447+
pm.sampling._choose_backend(["var1", "var2"])
447448
nd.assert_called_with(vars=["var1", "var2"])
448449

450+
def test_errors_and_warnings(self):
451+
with pm.Model():
452+
A = pm.Normal("A")
453+
B = pm.Uniform("B")
454+
strace = pm.sampling.NDArray(vars=[A, B])
455+
strace.setup(10, 0)
456+
457+
with pytest.raises(ValueError, match="from existing MultiTrace"):
458+
pm.sampling._choose_backend(trace=MultiTrace([strace]))
459+
460+
strace.record({"A": 2, "B_interval__": 0.1})
461+
assert len(strace) == 1
462+
with pytest.raises(ValueError, match="Continuation of traces"):
463+
pm.sampling._choose_backend(trace=strace)
464+
pass
465+
449466

450467
class TestSamplePPC(SeededTest):
451468
def test_normal_scalar(self):

0 commit comments

Comments
 (0)