Skip to content

Commit b6660f9

Browse files
authored
minor correction in sampling.py and starting.py (#4458)
Make deepcopy of start dicts in pm.sample and `pm.find_MAP` to prevent inplace modification of user variables closes #4456
1 parent e467bb9 commit b6660f9

File tree

4 files changed

+27
-2
lines changed

4 files changed

+27
-2
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
- `Theano-PyMC v1.1.2` also fixed an important issue in `tt.switch` that affected the behavior of several PyMC distributions, including at least the `Bernoulli` and `TruncatedNormal` (see[#4448](https://github.com/pymc-devs/pymc3/pull/4448))
1313
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
1414
- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)).
15+
- `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)).
1516

1617
## PyMC3 3.11.0 (21 January 2021)
1718

pymc3/sampling.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import warnings
2323

2424
from collections import defaultdict
25-
from copy import copy
25+
from copy import copy, deepcopy
2626
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
2727

2828
import arviz
@@ -423,6 +423,7 @@ def sample(
423423
p 0.609 0.047 0.528 0.699
424424
"""
425425
model = modelcontext(model)
426+
start = deepcopy(start)
426427
if start is None:
427428
check_start_vals(model.test_point, model)
428429
else:

pymc3/tests/test_sampling.py

+21
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,27 @@ def callback(trace, draw):
285285
assert len(trace) == trace_cancel_length
286286

287287

288+
def test_sample_find_MAP_does_not_modify_start():
289+
# see https://github.com/pymc-devs/pymc3/pull/4458
290+
with pm.Model():
291+
pm.Lognormal("untransformed")
292+
293+
# make sure find_Map does not modify the start dict
294+
start = {"untransformed": 2}
295+
pm.find_MAP(start=start)
296+
assert start == {"untransformed": 2}
297+
298+
# make sure sample does not modify the start dict
299+
start = {"untransformed": 0.2}
300+
pm.sample(draws=10, step=pm.Metropolis(), tune=5, start=start, chains=3)
301+
assert start == {"untransformed": 0.2}
302+
303+
# make sure sample does not modify the start when passes as list of dict
304+
start = [{"untransformed": 2}, {"untransformed": 0.2}]
305+
pm.sample(draws=10, step=pm.Metropolis(), tune=5, start=start, chains=2)
306+
assert start == [{"untransformed": 2}, {"untransformed": 0.2}]
307+
308+
288309
def test_empty_model():
289310
with pm.Model():
290311
pm.Normal("a", observed=1)

pymc3/tuning/starting.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
1818
@author: johnsalvatier
1919
"""
20+
import copy
21+
2022
import numpy as np
2123
import theano.gradient as tg
2224

@@ -96,7 +98,7 @@ def find_MAP(
9698
vars = inputvars(vars)
9799
disc_vars = list(typefilter(vars, discrete_types))
98100
allinmodel(vars, model)
99-
101+
start = copy.deepcopy(start)
100102
if start is None:
101103
start = model.test_point
102104
else:

0 commit comments

Comments
 (0)