Skip to content

Commit f5fdebb

Browse files
SpaakColCarroll
andauthored
add distribution details to GraphViz output (#4159)
* use new str() representations in GraphViz output * typo * updating tests * add debug print info for Travis-CI * fix failing test due to floating point numerical accuracy * another attempt at fixing float accuracy test fail * ensure string format as proper floatX Co-authored-by: Colin <[email protected]>
1 parent a05684b commit f5fdebb

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

Diff for: pymc3/model_graph.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -134,20 +134,22 @@ def _make_node(self, var_name, graph):
134134
if isinstance(v, SharedVariable):
135135
attrs["style"] = "rounded, filled"
136136

137-
# Get name for node
137+
# determine the shape for this node (default (Distribution) is ellipse)
138138
if v in self.model.potentials:
139-
distribution = "Potential"
140-
attrs["shape"] = "octagon"
141-
elif hasattr(v, "distribution"):
142-
distribution = v.distribution.__class__.__name__
139+
attrs['shape'] = 'octagon'
140+
elif isinstance(v, SharedVariable) or not hasattr(v, 'distribution'):
141+
# shared variables and Deterministic represented by a box
142+
attrs['shape'] = 'box'
143+
144+
if v in self.model.potentials:
145+
label = f'{var_name}\n~\nPotential'
143146
elif isinstance(v, SharedVariable):
144-
distribution = "Data"
145-
attrs["shape"] = "box"
147+
label = f'{var_name}\n~\nData'
146148
else:
147-
distribution = "Deterministic"
148-
attrs["shape"] = "box"
149+
label = str(v).replace(' ~ ', '\n~\n')
150+
151+
graph.node(var_name.replace(':', '&'), label, **attrs)
149152

150-
graph.node(var_name.replace(":", "&"), f"{var_name}\n~\n{distribution}", **attrs)
151153

152154
def get_plates(self):
153155
"""Rough but surprisingly accurate plate detection.

Diff for: pymc3/tests/test_data_container.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import pymc3 as pm
16+
from ..theanof import floatX
1617
from .helpers import SeededTest
1718
import numpy as np
1819
import pandas as pd
@@ -174,7 +175,8 @@ def test_model_to_graphviz_for_model_with_data_container(self):
174175
x = pm.Data("x", [1.0, 2.0, 3.0])
175176
y = pm.Data("y", [1.0, 2.0, 3.0])
176177
beta = pm.Normal("beta", 0, 10.0)
177-
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
178+
obs_sigma = floatX(np.sqrt(1e-2))
179+
pm.Normal("obs", beta * x, obs_sigma, observed=y)
178180
pm.sample(1000, init=None, tune=1000, chains=1)
179181

180182
g = pm.model_to_graphviz(model)
@@ -183,9 +185,9 @@ def test_model_to_graphviz_for_model_with_data_container(self):
183185
text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]'
184186
assert text in g.source
185187
# Didn't break ordinary variables?
186-
text = 'beta [label="beta\n~\nNormal"]'
188+
text = 'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]'
187189
assert text in g.source
188-
text = 'obs [label="obs\n~\nNormal" style=filled]'
190+
text = f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]'
189191
assert text in g.source
190192

191193
def test_explicit_coords(self):

0 commit comments

Comments
 (0)