@@ -252,9 +252,9 @@ def sample(
252
252
draws = 1000 ,
253
253
step = None ,
254
254
init = "auto" ,
255
- n_init = 200000 ,
255
+ n_init = 200_000 ,
256
256
start = None ,
257
- trace = None ,
257
+ trace : Optional [ Union [ BaseTrace , List [ str ]]] = None ,
258
258
chain_idx = 0 ,
259
259
chains = None ,
260
260
cores = None ,
@@ -296,10 +296,9 @@ def sample(
296
296
Defaults to ``trace.point(-1))`` if there is a trace provided and model.initial_point if not
297
297
(defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can
298
298
overwrite the default.
299
- trace : backend, list, or MultiTrace
300
- This should be a backend instance, a list of variables to track, or a MultiTrace object
301
- with past values. If a MultiTrace object is given, it must contain samples for the chain
302
- number ``chain``. If None or a list of variables, the NDArray backend is used.
299
+ trace : backend or list
300
+ This should be a backend instance, or a list of variables to track.
301
+ If None or a list of variables, the NDArray backend is used.
303
302
chain_idx : int
304
303
Chain number used to store sample in backend. If ``chains`` is greater than one, chain
305
304
numbers will start here.
@@ -813,7 +812,7 @@ def _sample(
813
812
start ,
814
813
draws : int ,
815
814
step = None ,
816
- trace = None ,
815
+ trace : Optional [ Union [ BaseTrace , List [ str ]]] = None ,
817
816
tune = None ,
818
817
model : Optional [Model ] = None ,
819
818
callback = None ,
@@ -839,10 +838,9 @@ def _sample(
839
838
The number of samples to draw
840
839
step : function
841
840
Step function
842
- trace : backend, list, or MultiTrace
843
- This should be a backend instance, a list of variables to track, or a MultiTrace object
844
- with past values. If a MultiTrace object is given, it must contain samples for the chain
845
- number ``chain``. If None or a list of variables, the NDArray backend is used.
841
+ trace : backend or list
842
+ This should be a backend instance, or a list of variables to track.
843
+ If None or a list of variables, the NDArray backend is used.
846
844
tune : int, optional
847
845
Number of iterations to tune, if applicable (defaults to None)
848
846
model : Model (optional if in ``with`` context)
@@ -899,10 +897,9 @@ def iter_sample(
899
897
start : dict
900
898
Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if
901
899
there is a trace provided and model.initial_point if not (defaults to empty dict)
902
- trace : backend, list, or MultiTrace
903
- This should be a backend instance, a list of variables to track, or a MultiTrace object
904
- with past values. If a MultiTrace object is given, it must contain samples for the chain
905
- number ``chain``. If None or a list of variables, the NDArray backend is used.
900
+ trace : backend or list
901
+ This should be a backend instance, or a list of variables to track.
902
+ If None or a list of variables, the NDArray backend is used.
906
903
chain : int, optional
907
904
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
908
905
will start here.
@@ -939,7 +936,7 @@ def _iter_sample(
939
936
draws ,
940
937
step ,
941
938
start = None ,
942
- trace = None ,
939
+ trace : Optional [ Union [ BaseTrace , List [ str ]]] = None ,
943
940
chain = 0 ,
944
941
tune = None ,
945
942
model = None ,
@@ -955,12 +952,10 @@ def _iter_sample(
955
952
step : function
956
953
Step function
957
954
start : dict, optional
958
- Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if
959
- there is a trace provided and model.initial_point if not (defaults to empty dict)
960
- trace : backend, list, MultiTrace, or None
961
- This should be a backend instance, a list of variables to track, or a MultiTrace object
962
- with past values. If a MultiTrace object is given, it must contain samples for the chain
963
- number ``chain``. If None or a list of variables, the NDArray backend is used.
955
+ Starting point in parameter space (or partial point). Defaults to model.initial_point if not (defaults to empty dict)
956
+ trace : backend or list
957
+ This should be a backend instance, or a list of variables to track.
958
+ If None or a list of variables, the NDArray backend is used.
964
959
chain : int, optional
965
960
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
966
961
will start here.
@@ -986,12 +981,9 @@ def _iter_sample(
986
981
if start is None :
987
982
start = {}
988
983
989
- strace = _choose_backend (trace , chain , model = model )
984
+ strace = _choose_backend (trace , model = model )
990
985
991
- if len (strace ) > 0 :
992
- model .update_start_vals (start , strace .point (- 1 ))
993
- else :
994
- model .update_start_vals (start , model .initial_point )
986
+ model .update_start_vals (start , model .initial_point )
995
987
996
988
try :
997
989
step = CompoundStep (step )
@@ -1258,7 +1250,7 @@ def _prepare_iter_population(
1258
1250
# 5. a PopulationStepper is configured for parallelized stepping
1259
1251
1260
1252
# 1. prepare a BaseTrace for each chain
1261
- traces = [_choose_backend (None , chain , model = model ) for chain in chains ]
1253
+ traces = [_choose_backend (None , model = model ) for chain in chains ]
1262
1254
for c , strace in enumerate (traces ):
1263
1255
# initialize the trace size and variable transforms
1264
1256
if len (strace ) > 0 :
@@ -1361,30 +1353,29 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
1361
1353
steppers [c ].report ._finalize (strace )
1362
1354
1363
1355
1364
- def _choose_backend (trace , chain , ** kwds ) -> Backend :
1356
+ def _choose_backend (trace : Optional [ Union [ BaseTrace , List [ str ]]] , ** kwds ) -> Backend :
1365
1357
"""Selects or creates a NDArray trace backend for a particular chain.
1366
1358
1367
1359
Parameters
1368
1360
----------
1369
- trace : BaseTrace, list, MultiTrace, or None
1370
- This should be a BaseTrace, list of variables to track,
1371
- or a MultiTrace object with past values.
1372
- If a MultiTrace object is given, it must contain samples for the chain number ``chain``.
1361
+ trace : BaseTrace, list, or None
1362
+ This should be a BaseTrace, or list of variables to track.
1373
1363
If None or a list of variables, the NDArray backend is used.
1374
- chain : int
1375
- Number of the chain of interest.
1376
1364
**kwds :
1377
1365
keyword arguments to forward to the backend creation
1378
1366
1379
1367
Returns
1380
1368
-------
1381
1369
trace : BaseTrace
1382
- A trace object for the selected chain
1370
+ The incoming, or a brand new trace object.
1383
1371
"""
1372
+ if isinstance (trace , BaseTrace ) and len (trace ) > 0 :
1373
+ raise ValueError ("Continuation of traces is no longer supported." )
1374
+ if isinstance (trace , MultiTrace ):
1375
+ raise ValueError ("Starting from existing MultiTrace objects is no longer supported." )
1376
+
1384
1377
if isinstance (trace , BaseTrace ):
1385
1378
return trace
1386
- if isinstance (trace , MultiTrace ):
1387
- return trace ._straces [chain ]
1388
1379
if trace is None :
1389
1380
return NDArray (** kwds )
1390
1381
@@ -1401,7 +1392,7 @@ def _mp_sample(
1401
1392
random_seed : list ,
1402
1393
start : list ,
1403
1394
progressbar = True ,
1404
- trace = None ,
1395
+ trace : Optional [ Union [ BaseTrace , List [ str ]]] = None ,
1405
1396
model = None ,
1406
1397
callback = None ,
1407
1398
discard_tuned_samples = True ,
@@ -1430,10 +1421,9 @@ def _mp_sample(
1430
1421
Starting points for each chain.
1431
1422
progressbar : bool
1432
1423
Whether or not to display a progress bar in the command line.
1433
- trace : BaseTrace, list, MultiTrace or None
1434
- This should be a backend instance, a list of variables to track, or a MultiTrace object
1435
- with past values. If a MultiTrace object is given, it must contain samples for the chain
1436
- number ``chain``. If None or a list of variables, the NDArray backend is used.
1424
+ trace : BaseTrace, list, or None
1425
+ This should be a backend instance, or a list of variables to track
1426
+ If None or a list of variables, the NDArray backend is used.
1437
1427
model : Model (optional if in ``with`` context)
1438
1428
callback : Callable
1439
1429
A function which gets called for every sample from the trace of a chain. The function is
@@ -1455,10 +1445,10 @@ def _mp_sample(
1455
1445
traces = []
1456
1446
for idx in range (chain , chain + chains ):
1457
1447
if trace is not None :
1458
- strace = _choose_backend (copy (trace ), idx , model = model )
1448
+ strace = _choose_backend (copy (trace ), model = model )
1459
1449
else :
1460
- strace = _choose_backend (None , idx , model = model )
1461
- # for user supply start value, fill-in missing value if the supplied
1450
+ strace = _choose_backend (None , model = model )
1451
+ # for user supplied start value, fill-in missing value if the supplied
1462
1452
# dict does not contain all parameters
1463
1453
model .update_start_vals (start [idx - chain ], model .initial_point )
1464
1454
if step .generates_stats and strace .supports_sampler_stats :
0 commit comments