Skip to content

Commit be048a4

Browse files
quantheoryricardoV94
authored andcommitted
Correct the order of rvs sent to compile_dlogp in find_MAP (#5923)
1 parent 729f79c commit be048a4

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

pymc/tests/test_starting.py

+26
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,32 @@ def test_find_MAP():
9898
close_to(map_est2["sigma"], 1, tol)
9999

100100

101+
def test_find_MAP_issue_5923():
102+
# Test that gradient-based minimization works well regardless of the order
103+
# of variables in `vars`, and even when starting a reasonable distance from
104+
# the MAP.
105+
tol = 2.0**-11 # 16 bit machine epsilon, a low bar
106+
data = np.random.randn(100)
107+
# data should be roughly mean 0, std 1, but let's
108+
# normalize anyway to get it really close
109+
data = (data - np.mean(data)) / np.std(data)
110+
111+
with Model():
112+
mu = Uniform("mu", -1, 1)
113+
sigma = Uniform("sigma", 0.5, 1.5)
114+
Normal("y", mu=mu, tau=sigma**-2, observed=data)
115+
116+
start = {"mu": -0.5, "sigma": 1.25}
117+
map_est1 = starting.find_MAP(progressbar=False, vars=[mu, sigma], start=start)
118+
map_est2 = starting.find_MAP(progressbar=False, vars=[sigma, mu], start=start)
119+
120+
close_to(map_est1["mu"], 0, tol)
121+
close_to(map_est1["sigma"], 1, tol)
122+
123+
close_to(map_est2["mu"], 0, tol)
124+
close_to(map_est2["sigma"], 1, tol)
125+
126+
101127
def test_find_MAP_issue_4488():
102128
# Test for https://github.com/pymc-devs/pymc/issues/4488
103129
with Model() as m:

pymc/tuning/starting.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -110,17 +110,17 @@ def find_MAP(
110110
start = ipfn(seed)
111111
model.check_start_vals(start)
112112

113-
var_names = {var.name for var in vars}
113+
vars_dict = {var.name: var for var in vars}
114114
x0 = DictToArrayBijection.map(
115-
{var_name: value for var_name, value in start.items() if var_name in var_names}
115+
{var_name: value for var_name, value in start.items() if var_name in vars_dict}
116116
)
117117

118118
# TODO: If the mapping is fixed, we can simply create graphs for the
119119
# mapping and avoid all this bijection overhead
120120
compiled_logp_func = DictToArrayBijection.mapf(model.compile_logp(jacobian=False), start)
121121
logp_func = lambda x: compiled_logp_func(RaveledVars(x, x0.point_map_info))
122122

123-
rvs = [model.values_to_rvs[value] for value in vars]
123+
rvs = [model.values_to_rvs[vars_dict[name]] for name, _, _ in x0.point_map_info]
124124
try:
125125
# This might be needed for calls to `dlogp_func`
126126
# start_map_info = tuple((v.name, v.shape, v.dtype) for v in vars)

0 commit comments

Comments
 (0)