57
57
58
58
import pymc as pm
59
59
60
- from pymc .aesaraf import at_rng , compile_pymc , identity , rvs_to_value_vars
60
+ from pymc .aesaraf import (
61
+ SeedSequenceSeed ,
62
+ at_rng ,
63
+ compile_pymc ,
64
+ find_rng_nodes ,
65
+ identity ,
66
+ reseed_rngs ,
67
+ rvs_to_value_vars ,
68
+ )
61
69
from pymc .backends import NDArray
62
70
from pymc .blocking import DictToArrayBijection
63
71
from pymc .initial_point import make_initial_point_fn
64
72
from pymc .model import modelcontext
73
+ from pymc .sampling import RandomState , _get_seeds_per_chain
65
74
from pymc .util import WithMemoization , locally_cachedmethod
66
75
from pymc .variational .updates import adagrad_window
67
76
from pymc .vartypes import discrete_types
@@ -1641,22 +1650,30 @@ def sample_dict_fn(self):
1641
1650
sampled = [self .rslice (name ) for name in names ]
1642
1651
sampled = self .set_size_and_deterministic (sampled , s , 0 )
1643
1652
sample_fn = compile_pymc ([s ], sampled )
1653
+ rng_nodes = find_rng_nodes (sampled )
1644
1654
1645
- def inner (draws = 100 ):
1655
+ def inner (draws = 100 , * , random_seed : SeedSequenceSeed = None ):
1656
+ if random_seed is not None :
1657
+ reseed_rngs (rng_nodes , random_seed )
1646
1658
_samples = sample_fn (draws )
1659
+
1647
1660
return {v_ : s_ for v_ , s_ in zip (names , _samples )}
1648
1661
1649
1662
return inner
1650
1663
1651
- def sample (self , draws = 500 , return_inferencedata = True , ** kwargs ):
1664
+ def sample (
1665
+ self , draws = 500 , * , random_seed : RandomState = None , return_inferencedata = True , ** kwargs
1666
+ ):
1652
1667
"""Draw samples from variational posterior.
1653
1668
1654
1669
Parameters
1655
1670
----------
1656
- draws: ` int`
1671
+ draws : int
1657
1672
Number of random samples.
1658
- return_inferencedata: `bool`
1659
- Return trace in Arviz format
1673
+ random_seed : int, RandomState or Generator, optional
1674
+ Seed for the random number generator.
1675
+ return_inferencedata : bool
1676
+ Return trace in Arviz format.
1660
1677
1661
1678
Returns
1662
1679
-------
@@ -1666,7 +1683,9 @@ def sample(self, draws=500, return_inferencedata=True, **kwargs):
1666
1683
# TODO: add tests for include_transformed case
1667
1684
kwargs ["log_likelihood" ] = False
1668
1685
1669
- samples = self .sample_dict_fn (draws ) # type: dict
1686
+ if random_seed is not None :
1687
+ (random_seed ,) = _get_seeds_per_chain (random_seed , 1 )
1688
+ samples = self .sample_dict_fn (draws , random_seed = random_seed ) # type: dict
1670
1689
points = ({name : records [i ] for name , records in samples .items ()} for i in range (draws ))
1671
1690
1672
1691
trace = NDArray (
0 commit comments