-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
Copy pathdisaster_model_arbitrary_deterministic.py
50 lines (39 loc) · 1.92 KB
/
disaster_model_arbitrary_deterministic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""
Similar to disaster_model.py, but for arbitrary
deterministics which do not work with Theano.
Note that gradient-based samplers will not work.
"""
import pymc3 as pm
import theano.tensor as tt
from numpy import arange, array
# Names exported by ``from <module> import *``: the observed data array plus
# the model variables (three priors, the rate expression and the likelihood)
# created inside the ``pm.Model()`` context further down.
__all__ = [
    'disasters_data',
    'switchpoint',
    'early_mean',
    'late_mean',
    'rate',
    'disasters',
]
# Time series of recorded coal mining disasters in the UK from 1851 to 1962,
# one count per year (111 values in total).
disasters_data = array([
    4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
    3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
    2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0,
    1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
    0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
    3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
    0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])

# Number of observed years, i.e. the length of the 1-D count array.
years = disasters_data.shape[0]
# Switchpoint model: yearly disaster counts are Poisson-distributed with one
# mean up to an unknown switchpoint index and a different mean afterwards.
# Posterior sampling uses only non-gradient step methods (Slice for the two
# continuous means, Metropolis for the discrete switchpoint).
# NOTE: sampling and plotting execute at module level, i.e. on import.
with pm.Model() as model:
    # Prior for distribution of switchpoint location, bounded by
    # lower=0 and upper=years.
    switchpoint = pm.DiscreteUniform('switchpoint', lower=0, upper=years)
    # Priors for pre- and post-switch mean number of disasters.
    early_mean = pm.Exponential('early_mean', lam=1.)
    late_mean = pm.Exponential('late_mean', lam=1.)

    # Allocate appropriate Poisson rates to years before and after the
    # current switchpoint location: year indices <= switchpoint get
    # early_mean, all later indices get late_mean.
    # NOTE(review): the module docstring talks about arbitrary (non-Theano)
    # deterministics, but this rate is built with the Theano op tt.switch;
    # confirm which formulation is intended.
    idx = arange(years)
    rate = tt.switch(switchpoint >= idx, early_mean, late_mean)

    # Data likelihood.
    disasters = pm.Poisson('disasters', rate, observed=disasters_data)

    # Use slice sampler for means.
    step1 = pm.Slice([early_mean, late_mean])
    # Use Metropolis for switchpoint, since it accommodates discrete variables.
    step2 = pm.Metropolis([switchpoint])

    # Initial values for stochastic nodes.
    start = {'early_mean': 2., 'late_mean': 3.}

    # Draw 1000 samples after 500 tuning steps using the compound step,
    # then plot the traces. Both calls sit inside the model context.
    # NOTE(review): newer PyMC3 releases deprecate ``njobs`` in favour of
    # ``cores``; confirm against the installed version before renaming.
    tr = pm.sample(1000, tune=500, start=start, step=[step1, step2], njobs=2)
    pm.traceplot(tr)