Skip to content

Commit 1184b0c

Browse files
Create arrow to observation nodes subject to arbitrary dtype casting in pm.model_to_graphviz (#6011)
* Adding graphviz area for observed nodes with dtype casting * Add unit test for obs dtype casting in ModelGraph * check data container for owner attribute * Adding comments for make_compute_graph pattern matching * Repositioned 'While True' loop to better account for arbitary number of dtype casts * Fix grammar in comment on dtype casting
1 parent 3361176 commit 1184b0c

File tree

2 files changed

+46
-5
lines changed

2 files changed

+46
-5
lines changed

pymc/model_graph.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from aesara.compile.sharedvalue import SharedVariable
2121
from aesara.graph import Apply
2222
from aesara.graph.basic import ancestors, walk
23+
from aesara.scalar.basic import Cast
24+
from aesara.tensor.elemwise import Elemwise
2325
from aesara.tensor.random.op import RandomVariable
2426
from aesara.tensor.var import TensorConstant, TensorVariable
2527

@@ -98,13 +100,28 @@ def make_compute_graph(
98100
input_map[var_name] = input_map[var_name].union(parent_name)
99101

100102
if hasattr(var.tag, "observations"):
101-
try:
102-
obs_name = var.tag.observations.name
103+
obs_node = var.tag.observations
104+
105+
# loop created so that the elif block can go through this again
106+
# and remove any intermediate ops, notably dtype casting, to observations
107+
while True:
108+
109+
obs_name = obs_node.name
103110
if obs_name and obs_name != var_name:
104111
input_map[var_name] = input_map[var_name].difference({obs_name})
105112
input_map[obs_name] = input_map[obs_name].union({var_name})
106-
except AttributeError:
107-
pass
113+
break
114+
elif (
115+
# for cases where observations are cast to a certain dtype
116+
# see issue 5795: https://github.com/pymc-devs/pymc/issues/5795
117+
obs_node.owner
118+
and isinstance(obs_node.owner.op, Elemwise)
119+
and isinstance(obs_node.owner.op.scalar_op, Cast)
120+
):
121+
# we can retrieve the observation node by going up the graph
122+
obs_node = obs_node.owner.inputs[0]
123+
else:
124+
break
108125

109126
return input_map
110127

pymc/tests/test_model_graph.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pytest
1717

1818
from aesara.compile.sharedvalue import SharedVariable
19+
from aesara.tensor.var import TensorConstant
1920

2021
import pymc as pm
2122

@@ -154,6 +155,25 @@ def model_unnamed_observed_node():
154155
return model, compute_graph, plates
155156

156157

158+
def model_observation_dtype_casting():
159+
"""
160+
Model at the source of the following issue: https://github.com/pymc-devs/pymc/issues/5795
161+
"""
162+
with pm.Model() as model:
163+
data = pm.ConstantData("data", [0, 0, 1, 1], dtype=int)
164+
p = pm.Beta("p", 1, 1)
165+
bern = pm.Bernoulli("response", p, observed=data)
166+
167+
compute_graph = {
168+
"p": set(),
169+
"response": {"p"},
170+
"data": {"response"},
171+
}
172+
plates = {"": {"p"}, "4": {"data", "response"}}
173+
174+
return model, compute_graph, plates
175+
176+
157177
class BaseModelGraphTest(SeededTest):
158178
model_func = None
159179

@@ -166,7 +186,7 @@ def test_inputs(self):
166186
for child, parents_in_plot in self.compute_graph.items():
167187
var = self.model[child]
168188
parents_in_graph = self.model_graph.get_parent_names(var)
169-
if isinstance(var, SharedVariable):
189+
if isinstance(var, (SharedVariable, TensorConstant)):
170190
# observed data also doesn't have parents in the compute graph!
171191
# But for the visualization we like them to become decendants of the
172192
# RVs that these observations belong to.
@@ -236,6 +256,10 @@ class TestUnnamedObservedNodes(BaseModelGraphTest):
236256
model_func = model_unnamed_observed_node
237257

238258

259+
class TestObservationDtypeCasting(BaseModelGraphTest):
260+
model_func = model_observation_dtype_casting
261+
262+
239263
class TestVariableSelection:
240264
@pytest.mark.parametrize(
241265
"var_names, vars_to_plot, compute_graph",

0 commit comments

Comments
 (0)