Skip to content

Commit ce07e70

Browse files
larryshamalamamichaelosthege
authored andcommitted
Remove self-directing arrow in observed nodes
1 parent 5d003dc commit ce07e70

File tree

2 files changed

+33
-24
lines changed

2 files changed

+33
-24
lines changed

Diff for: pymc/model_graph.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def make_compute_graph(
100100
if hasattr(var.tag, "observations"):
101101
try:
102102
obs_name = var.tag.observations.name
103-
if obs_name:
103+
if obs_name and obs_name != var_name:
104104
input_map[var_name] = input_map[var_name].difference({obs_name})
105105
input_map[obs_name] = input_map[obs_name].union({var_name})
106106
except AttributeError:

Diff for: pymc/tests/test_model_graph.py

+32-23
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,28 @@ def model_with_dims():
132132
return pmodel, compute_graph, plates
133133

134134

135+
def model_unnamed_observed_node():
136+
"""
137+
Model at the source of the following issue: https://github.com/pymc-devs/pymc/issues/5892
138+
"""
139+
data = [-1, 0, 0.5, 1]
140+
141+
with pm.Model() as model:
142+
mu = pm.Normal(name="mu", mu=0.0, sigma=5.0)
143+
y = pm.Normal(name="y", mu=mu, sigma=3.0, observed=data)
144+
145+
compute_graph = {
146+
"mu": set(),
147+
"y": {"mu"},
148+
}
149+
plates = {
150+
"": {"mu"},
151+
"4": {"y"},
152+
}
153+
154+
return model, compute_graph, plates
155+
156+
135157
class BaseModelGraphTest(SeededTest):
136158
model_func = None
137159

@@ -202,21 +224,16 @@ def model_with_different_descendants():
202224
return pmodel2
203225

204226

205-
class TestParents:
206-
@pytest.mark.parametrize(
207-
"var_name, parent_names",
208-
[
209-
("L", {"pred"}),
210-
("pred", {"intermediate"}),
211-
("intermediate", {"a", "b"}),
212-
("c", {"a", "b"}),
213-
("a", set()),
214-
("b", set()),
215-
],
216-
)
217-
def test_get_parent_names(self, var_name, parent_names):
218-
mg = ModelGraph(model_with_different_descendants())
219-
mg.get_parent_names(mg.model[var_name]) == parent_names
227+
class TestImputationModel(BaseModelGraphTest):
228+
model_func = model_with_imputations
229+
230+
231+
class TestModelWithDims(BaseModelGraphTest):
232+
model_func = model_with_dims
233+
234+
235+
class TestUnnamedObservedNodes(BaseModelGraphTest):
236+
model_func = model_unnamed_observed_node
220237

221238

222239
class TestVariableSelection:
@@ -260,11 +277,3 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph):
260277
mg = ModelGraph(model_with_different_descendants())
261278
assert set(mg.vars_to_plot(var_names=var_names)) == set(vars_to_plot)
262279
assert mg.make_compute_graph(var_names=var_names) == compute_graph
263-
264-
265-
class TestImputationModel(BaseModelGraphTest):
266-
model_func = model_with_imputations
267-
268-
269-
class TestModelWithDims(BaseModelGraphTest):
270-
model_func = model_with_dims

0 commit comments

Comments
 (0)