Skip to content

Commit ca692e9

Browse files
committed
Fix default use of global seeding
1 parent 070700c commit ca692e9

File tree

2 files changed

+41
-28
lines changed

2 files changed

+41
-28
lines changed

pymc/parallel_sampling.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import traceback
2222

2323
from collections import namedtuple
24-
from typing import Dict, List, Sequence
24+
from typing import TYPE_CHECKING, Dict, List, Sequence
2525

2626
import cloudpickle
2727
import numpy as np
@@ -32,6 +32,10 @@
3232
from pymc.blocking import DictToArrayBijection
3333
from pymc.exceptions import SamplingError
3434

35+
# Avoid circular import
36+
if TYPE_CHECKING:
37+
from pymc.sampling import RandomSeed
38+
3539
logger = logging.getLogger("pymc")
3640

3741

@@ -389,7 +393,7 @@ def __init__(
389393
tune: int,
390394
chains: int,
391395
cores: int,
392-
seeds: list,
396+
seeds: Sequence["RandomSeed"],
393397
start_points: Sequence[Dict[str, np.ndarray]],
394398
step_method,
395399
start_chain_num: int = 0,

pymc/sampling.py

+35-26
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@
104104
PointList: TypeAlias = List[PointType]
105105
Backend: TypeAlias = Union[BaseTrace, MultiTrace, NDArray]
106106

107+
RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]]
108+
107109
_log = logging.getLogger("pymc")
108110

109111

@@ -437,14 +439,14 @@ def sample(
437439
if random_seed == -1:
438440
random_seed = None
439441
if chains == 1 and isinstance(random_seed, int):
440-
random_seed = [random_seed]
442+
random_seed_list = [random_seed]
441443

442444
if random_seed is None or isinstance(random_seed, int):
443445
if random_seed is not None:
444446
np.random.seed(random_seed)
445-
random_seed = [np.random.randint(2**30) for _ in range(chains)]
447+
random_seed_list = [np.random.randint(2**30) for _ in range(chains)]
446448

447-
if not isinstance(random_seed, abc.Iterable):
449+
if not isinstance(random_seed_list, abc.Iterable):
448450
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
449451

450452
if not discard_tuned_samples and not return_inferencedata:
@@ -490,7 +492,7 @@ def sample(
490492
chains=chains,
491493
n_init=n_init,
492494
model=model,
493-
seeds=random_seed,
495+
seeds=random_seed_list,
494496
progressbar=progressbar,
495497
jitter_max_retries=jitter_max_retries,
496498
tune=tune,
@@ -506,7 +508,7 @@ def sample(
506508
jitter_rvs=filter_rvs_to_jitter(step),
507509
chains=chains,
508510
)
509-
initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed)]
511+
initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)]
510512

511513
# One final check that shapes and logps at the starting points are okay.
512514
for ip in initial_points:
@@ -523,7 +525,6 @@ def sample(
523525
"tune": tune,
524526
"progressbar": progressbar,
525527
"model": model,
526-
"random_seed": random_seed,
527528
"cores": cores,
528529
"callback": callback,
529530
"discard_tuned_samples": discard_tuned_samples,
@@ -542,6 +543,19 @@ def sample(
542543
)
543544

544545
parallel = cores > 1 and chains > 1 and not has_population_samplers
546+
# At some point it was decided that PyMC should not set a global seed by default,
547+
# unless the user specified a seed. This is a symptom of the fact that PyMC samplers
548+
# are built around global seeding. This branch makes sure we maintain this unspoken
549+
# rule. See https://github.com/pymc-devs/pymc/pull/1395.
550+
if parallel:
551+
# For parallel sampling we can pass the list of random seeds directly, as
552+
# global seeding will only be called inside each process
553+
sample_args["random_seed"] = random_seed_list
554+
else:
555+
# We pass None if the original random seed was None. The single core sampler
556+
# methods will only set a global seed when it is not None.
557+
sample_args["random_seed"] = random_seed if random_seed is None else random_seed_list
558+
545559
t_start = time.time()
546560
if parallel:
547561
_log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)")
@@ -674,7 +688,7 @@ def _sample_many(
674688
chain: int,
675689
chains: int,
676690
start: Sequence[PointType],
677-
random_seed: list,
691+
random_seed: Optional[Sequence[RandomSeed]],
678692
step,
679693
callback=None,
680694
**kwargs,
@@ -691,7 +705,7 @@ def _sample_many(
691705
Total number of chains to sample.
692706
start: list
693707
Starting points for each chain
694-
random_seed: list
708+
random_seed: list of random seeds, optional
695709
A list of seeds, one for each chain
696710
step: function
697711
Step function
@@ -708,7 +722,7 @@ def _sample_many(
708722
chain=chain + i,
709723
start=start[i],
710724
step=step,
711-
random_seed=random_seed[i],
725+
random_seed=None if random_seed is None else random_seed[i],
712726
callback=callback,
713727
**kwargs,
714728
)
@@ -731,7 +745,7 @@ def _sample_population(
731745
chain: int,
732746
chains: int,
733747
start: Sequence[PointType],
734-
random_seed,
748+
random_seed: RandomSeed,
735749
step,
736750
tune: int,
737751
model,
@@ -751,8 +765,7 @@ def _sample_population(
751765
The total number of chains in the population
752766
start : list
753767
Start points for each chain
754-
random_seed : int or list of ints, optional
755-
A list is accepted if more if ``cores`` is greater than one.
768+
random_seed : single random seed, optional
756769
step : function
757770
Step function (should be or contain a population step method)
758771
tune : int
@@ -793,7 +806,7 @@ def _sample(
793806
*,
794807
chain: int,
795808
progressbar: bool,
796-
random_seed,
809+
random_seed: RandomSeed,
797810
start: PointType,
798811
draws: int,
799812
step=None,
@@ -815,8 +828,7 @@ def _sample(
815828
Whether or not to display a progress bar in the command line. The bar shows the percentage
816829
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
817830
time until completion ("expected time of arrival"; ETA).
818-
random_seed : int or list of ints
819-
A list is accepted if ``cores`` is greater than one.
831+
random_seed : single random seed
820832
start : dict
821833
Starting point in parameter space (or partial point)
822834
draws : int
@@ -871,7 +883,7 @@ def iter_sample(
871883
chain: int = 0,
872884
tune: int = 0,
873885
model: Optional[Model] = None,
874-
random_seed: Optional[Union[int, List[int]]] = None,
886+
random_seed: RandomSeed = None,
875887
callback=None,
876888
) -> Iterator[MultiTrace]:
877889
"""Generate a trace on each iteration using the given step method.
@@ -896,8 +908,7 @@ def iter_sample(
896908
tune : int, optional
897909
Number of iterations to tune (defaults to 0).
898910
model : Model (optional if in ``with`` context)
899-
random_seed : int or list of ints, optional
900-
A list is accepted if more if ``cores`` is greater than one.
911+
random_seed : single random seed, optional
901912
callback :
902913
A function which gets called for every sample from the trace of a chain. The function is
903914
called with the trace and the current draw and will contain all samples for a single trace.
@@ -930,7 +941,7 @@ def _iter_sample(
930941
chain: int = 0,
931942
tune: int = 0,
932943
model=None,
933-
random_seed=None,
944+
random_seed: RandomSeed = None,
934945
callback=None,
935946
) -> Iterator[Tuple[BaseTrace, bool]]:
936947
"""Generator for sampling one chain. (Used in singleprocess sampling.)
@@ -953,8 +964,7 @@ def _iter_sample(
953964
tune : int, optional
954965
Number of iterations to tune (defaults to 0).
955966
model : Model (optional if in ``with`` context)
956-
random_seed : int or list of ints, optional
957-
A list is accepted if more if ``cores`` is greater than one.
967+
random_seed : single random seed, optional
958968
959969
Yields
960970
------
@@ -1194,7 +1204,7 @@ def _prepare_iter_population(
11941204
parallelize: bool,
11951205
tune: int,
11961206
model=None,
1197-
random_seed=None,
1207+
random_seed: RandomSeed = None,
11981208
progressbar=True,
11991209
) -> Iterator[Sequence[BaseTrace]]:
12001210
"""Prepare a PopulationStepper and traces for population sampling.
@@ -1214,8 +1224,7 @@ def _prepare_iter_population(
12141224
tune : int
12151225
Number of iterations to tune.
12161226
model : Model (optional if in ``with`` context)
1217-
random_seed : int or list of ints, optional
1218-
A list is accepted if more if ``cores`` is greater than one.
1227+
random_seed : single random seed, optional
12191228
progressbar : bool
12201229
``progressbar`` argument for the ``PopulationStepper``, (defaults to True)
12211230
@@ -1400,7 +1409,7 @@ def _mp_sample(
14001409
chains: int,
14011410
cores: int,
14021411
chain: int,
1403-
random_seed: list,
1412+
random_seed: Sequence[RandomSeed],
14041413
start: Sequence[PointType],
14051414
progressbar: bool = True,
14061415
trace: Optional[Union[BaseTrace, List[str]]] = None,
@@ -1426,7 +1435,7 @@ def _mp_sample(
14261435
The number of chains to run in parallel.
14271436
chain : int
14281437
Number of the first chain.
1429-
random_seed : list of ints
1438+
random_seed : list of random seeds
14301439
Random seeds for each chain.
14311440
start : list
14321441
Starting points for each chain.

0 commit comments

Comments
 (0)