104
104
PointList : TypeAlias = List [PointType ]
105
105
Backend : TypeAlias = Union [BaseTrace , MultiTrace , NDArray ]
106
106
107
+ RandomSeed = Optional [Union [int , Sequence [int ], np .ndarray ]]
108
+
107
109
_log = logging .getLogger ("pymc" )
108
110
109
111
@@ -437,14 +439,14 @@ def sample(
437
439
if random_seed == - 1 :
438
440
random_seed = None
439
441
if chains == 1 and isinstance (random_seed , int ):
440
- random_seed = [random_seed ]
442
+ random_seed_list = [random_seed ]
441
443
442
444
if random_seed is None or isinstance (random_seed , int ):
443
445
if random_seed is not None :
444
446
np .random .seed (random_seed )
445
- random_seed = [np .random .randint (2 ** 30 ) for _ in range (chains )]
447
+ random_seed_list = [np .random .randint (2 ** 30 ) for _ in range (chains )]
446
448
447
- if not isinstance (random_seed , abc .Iterable ):
449
+ if not isinstance (random_seed_list , abc .Iterable ):
448
450
raise TypeError ("Invalid value for `random_seed`. Must be tuple, list or int" )
449
451
450
452
if not discard_tuned_samples and not return_inferencedata :
@@ -490,7 +492,7 @@ def sample(
490
492
chains = chains ,
491
493
n_init = n_init ,
492
494
model = model ,
493
- seeds = random_seed ,
495
+ seeds = random_seed_list ,
494
496
progressbar = progressbar ,
495
497
jitter_max_retries = jitter_max_retries ,
496
498
tune = tune ,
@@ -506,7 +508,7 @@ def sample(
506
508
jitter_rvs = filter_rvs_to_jitter (step ),
507
509
chains = chains ,
508
510
)
509
- initial_points = [ipfn (seed ) for ipfn , seed in zip (ipfns , random_seed )]
511
+ initial_points = [ipfn (seed ) for ipfn , seed in zip (ipfns , random_seed_list )]
510
512
511
513
# One final check that shapes and logps at the starting points are okay.
512
514
for ip in initial_points :
@@ -523,7 +525,6 @@ def sample(
523
525
"tune" : tune ,
524
526
"progressbar" : progressbar ,
525
527
"model" : model ,
526
- "random_seed" : random_seed ,
527
528
"cores" : cores ,
528
529
"callback" : callback ,
529
530
"discard_tuned_samples" : discard_tuned_samples ,
@@ -542,6 +543,19 @@ def sample(
542
543
)
543
544
544
545
parallel = cores > 1 and chains > 1 and not has_population_samplers
546
+ # At some point it was decided that PyMC should not set a global seed by default,
547
+ # unless the user specified a seed. This is a symptom of the fact that PyMC samplers
548
+ # are built around global seeding. This branch makes sure we maintain this unspoken
549
+ # rule. See https://github.com/pymc-devs/pymc/pull/1395.
550
+ if parallel :
551
+ # For parallel sampling we can pass the list of random seeds directly, as
552
+ # global seeding will only be called inside each process
553
+ sample_args ["random_seed" ] = random_seed_list
554
+ else :
555
+ # We pass None if the original random seed was None. The single core sampler
556
+ # methods will only set a global seed when it is not None.
557
+ sample_args ["random_seed" ] = random_seed if random_seed is None else random_seed_list
558
+
545
559
t_start = time .time ()
546
560
if parallel :
547
561
_log .info (f"Multiprocess sampling ({ chains } chains in { cores } jobs)" )
@@ -674,7 +688,7 @@ def _sample_many(
674
688
chain : int ,
675
689
chains : int ,
676
690
start : Sequence [PointType ],
677
- random_seed : list ,
691
+ random_seed : Optional [ Sequence [ RandomSeed ]] ,
678
692
step ,
679
693
callback = None ,
680
694
** kwargs ,
@@ -691,7 +705,7 @@ def _sample_many(
691
705
Total number of chains to sample.
692
706
start: list
693
707
Starting points for each chain
694
- random_seed: list
708
+ random_seed: list of random seeds, optional
695
709
A list of seeds, one for each chain
696
710
step: function
697
711
Step function
@@ -708,7 +722,7 @@ def _sample_many(
708
722
chain = chain + i ,
709
723
start = start [i ],
710
724
step = step ,
711
- random_seed = random_seed [i ],
725
+ random_seed = None if random_seed is None else random_seed [i ],
712
726
callback = callback ,
713
727
** kwargs ,
714
728
)
@@ -731,7 +745,7 @@ def _sample_population(
731
745
chain : int ,
732
746
chains : int ,
733
747
start : Sequence [PointType ],
734
- random_seed ,
748
+ random_seed : RandomSeed ,
735
749
step ,
736
750
tune : int ,
737
751
model ,
@@ -751,8 +765,7 @@ def _sample_population(
751
765
The total number of chains in the population
752
766
start : list
753
767
Start points for each chain
754
- random_seed : int or list of ints, optional
755
- A list is accepted if more if ``cores`` is greater than one.
768
+ random_seed : single random seed, optional
756
769
step : function
757
770
Step function (should be or contain a population step method)
758
771
tune : int
@@ -793,7 +806,7 @@ def _sample(
793
806
* ,
794
807
chain : int ,
795
808
progressbar : bool ,
796
- random_seed ,
809
+ random_seed : RandomSeed ,
797
810
start : PointType ,
798
811
draws : int ,
799
812
step = None ,
@@ -815,8 +828,7 @@ def _sample(
815
828
Whether or not to display a progress bar in the command line. The bar shows the percentage
816
829
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
817
830
time until completion ("expected time of arrival"; ETA).
818
- random_seed : int or list of ints
819
- A list is accepted if ``cores`` is greater than one.
831
+ random_seed : single random seed
820
832
start : dict
821
833
Starting point in parameter space (or partial point)
822
834
draws : int
@@ -871,7 +883,7 @@ def iter_sample(
871
883
chain : int = 0 ,
872
884
tune : int = 0 ,
873
885
model : Optional [Model ] = None ,
874
- random_seed : Optional [ Union [ int , List [ int ]]] = None ,
886
+ random_seed : RandomSeed = None ,
875
887
callback = None ,
876
888
) -> Iterator [MultiTrace ]:
877
889
"""Generate a trace on each iteration using the given step method.
@@ -896,8 +908,7 @@ def iter_sample(
896
908
tune : int, optional
897
909
Number of iterations to tune (defaults to 0).
898
910
model : Model (optional if in ``with`` context)
899
- random_seed : int or list of ints, optional
900
- A list is accepted if more if ``cores`` is greater than one.
911
+ random_seed : single random seed, optional
901
912
callback :
902
913
A function which gets called for every sample from the trace of a chain. The function is
903
914
called with the trace and the current draw and will contain all samples for a single trace.
@@ -930,7 +941,7 @@ def _iter_sample(
930
941
chain : int = 0 ,
931
942
tune : int = 0 ,
932
943
model = None ,
933
- random_seed = None ,
944
+ random_seed : RandomSeed = None ,
934
945
callback = None ,
935
946
) -> Iterator [Tuple [BaseTrace , bool ]]:
936
947
"""Generator for sampling one chain. (Used in singleprocess sampling.)
@@ -953,8 +964,7 @@ def _iter_sample(
953
964
tune : int, optional
954
965
Number of iterations to tune (defaults to 0).
955
966
model : Model (optional if in ``with`` context)
956
- random_seed : int or list of ints, optional
957
- A list is accepted if more if ``cores`` is greater than one.
967
+ random_seed : single random seed, optional
958
968
959
969
Yields
960
970
------
@@ -1194,7 +1204,7 @@ def _prepare_iter_population(
1194
1204
parallelize : bool ,
1195
1205
tune : int ,
1196
1206
model = None ,
1197
- random_seed = None ,
1207
+ random_seed : RandomSeed = None ,
1198
1208
progressbar = True ,
1199
1209
) -> Iterator [Sequence [BaseTrace ]]:
1200
1210
"""Prepare a PopulationStepper and traces for population sampling.
@@ -1214,8 +1224,7 @@ def _prepare_iter_population(
1214
1224
tune : int
1215
1225
Number of iterations to tune.
1216
1226
model : Model (optional if in ``with`` context)
1217
- random_seed : int or list of ints, optional
1218
- A list is accepted if more if ``cores`` is greater than one.
1227
+ random_seed : single random seed, optional
1219
1228
progressbar : bool
1220
1229
``progressbar`` argument for the ``PopulationStepper``, (defaults to True)
1221
1230
@@ -1400,7 +1409,7 @@ def _mp_sample(
1400
1409
chains : int ,
1401
1410
cores : int ,
1402
1411
chain : int ,
1403
- random_seed : list ,
1412
+ random_seed : Sequence [ RandomSeed ] ,
1404
1413
start : Sequence [PointType ],
1405
1414
progressbar : bool = True ,
1406
1415
trace : Optional [Union [BaseTrace , List [str ]]] = None ,
@@ -1426,7 +1435,7 @@ def _mp_sample(
1426
1435
The number of chains to run in parallel.
1427
1436
chain : int
1428
1437
Number of the first chain.
1429
- random_seed : list of ints
1438
+ random_seed : list of random seeds
1430
1439
Random seeds for each chain.
1431
1440
start : list
1432
1441
Starting points for each chain.
0 commit comments