Skip to content

Commit 611eceb

Browse files
authored
add potentials to priors when running abc (#4016)
* add potentials to priors when running abc * update release notes * add message about potentials being added to the prior term
1 parent ec2d79b commit 611eceb

File tree

4 files changed

+23
-2
lines changed

4 files changed

+23
-2
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
- Add sampler stats `process_time_diff`, `perf_counter_diff` and `perf_counter_start`, that record wall and CPU times for each NUTS and HMC sample (see [ #3986](https://github.com/pymc-devs/pymc3/pull/3986)).
1414
- Extend `keep_size` argument handling for `sample_posterior_predictive` and `fast_sample_posterior_predictive`, to work on arviz InferenceData and xarray Dataset input values. (see [PR #4006](https://github.com/pymc-devs/pymc3/pull/4006) and [Issue #4004](https://github.com/pymc-devs/pymc3/issues/4004).
1515
- SMC-ABC: add the wasserstein and energy distance functions. Refactor API, the distance, sum_stats and epsilon arguments are now passed `pm.Simulator` instead of `pm.sample_smc`. Add random method to `pm.Simulator`. Add option to save the simulated data. Improves LaTeX representation [#3996](https://github.com/pymc-devs/pymc3/pull/3996)
16+
- SMC-ABC: Allow use of potentials by adding them to the prior term. [#4016](https://github.com/pymc-devs/pymc3/pull/4016)
1617

1718
## PyMC3 3.9.2 (24 June 2020)
1819
### Maintenance

pymc3/smc/sample_smc.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def sample_smc(
137137
_log = logging.getLogger("pymc3")
138138
_log.info("Initializing SMC sampler...")
139139

140+
model = modelcontext(model)
140141
if cores is None:
141142
cores = _cpu_count()
142143

@@ -165,8 +166,10 @@ def sample_smc(
165166

166167
if kernel.lower() == "abc":
167168
warnings.warn(EXPERIMENTAL_WARNING)
168-
if len(modelcontext(model).observed_RVs) != 1:
169+
if len(model.observed_RVs) != 1:
169170
warnings.warn("SMC-ABC only works properly with models with one observed variable")
171+
if model.potentials:
172+
_log.info("Potentials will be added to the prior term")
170173

171174
params = (
172175
draws,

pymc3/smc/smc.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
from scipy.special import logsumexp
1919
from theano import function as theano_function
20+
import theano.tensor as tt
2021

2122
from ..model import modelcontext, Point
2223
from ..theanof import floatX, inputvars, make_shared_replacements, join_nonshared_inputs
@@ -100,9 +101,11 @@ def setup_kernel(self):
100101
Set up the likelihood logp function based on the chosen kernel
101102
"""
102103
shared = make_shared_replacements(self.variables, self.model)
103-
self.prior_logp_func = logp_forw([self.model.varlogpt], self.variables, shared)
104104

105105
if self.kernel.lower() == "abc":
106+
factors = [var.logpt for var in self.model.free_RVs]
107+
factors += [tt.sum(factor) for factor in self.model.potentials]
108+
self.prior_logp_func = logp_forw([tt.sum(factors)], self.variables, shared)
106109
simulator = self.model.observed_RVs[0]
107110
distance = simulator.distribution.distance
108111
sum_stat = simulator.distribution.sum_stat
@@ -120,6 +123,7 @@ def setup_kernel(self):
120123
self.save_sim_data,
121124
)
122125
elif self.kernel.lower() == "metropolis":
126+
self.prior_logp_func = logp_forw([self.model.varlogpt], self.variables, shared)
123127
self.likelihood_logp_func = logp_forw([self.model.datalogpt], self.variables, shared)
124128

125129
def initialize_logp(self):

pymc3/tests/test_smc.py

+13
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,14 @@ def abs_diff(eps, obs_data, sim_data):
130130
observed=self.data,
131131
)
132132

133+
with pm.Model() as self.SMABC_potential:
134+
a = pm.Normal("a", mu=0, sigma=1)
135+
b = pm.HalfNormal("b", sigma=1)
136+
c = pm.Potential("c", pm.math.switch(a > 0, 0, -np.inf))
137+
s = pm.Simulator(
138+
"s", normal_sim, params=(a, b), sum_stat="sort", epsilon=1, observed=self.data
139+
)
140+
133141
def test_one_gaussian(self):
134142
with self.SMABC_test:
135143
trace = pm.sample_smc(draws=1000, kernel="ABC")
@@ -157,6 +165,11 @@ def test_custom_dist_sum(self):
157165
with self.SMABC_test2:
158166
trace = pm.sample_smc(draws=1000, kernel="ABC")
159167

168+
def test_potential(self):
169+
with self.SMABC_potential:
170+
trace = pm.sample_smc(draws=1000, kernel="ABC")
171+
assert np.all(trace["a"] >= 0)
172+
160173
def test_automatic_use_of_sort(self):
161174
with pm.Model() as model:
162175
s_g = pm.Simulator(

0 commit comments

Comments
 (0)