Skip to content

Commit e783106

Browse files
authored
Better handling of variable names in e.g. GraphViz graphs (#4403)
* handle changed API of theano.gof.graph.stack_search * convert n and eta to tensors, explicitly list parameters for repr * improve robustness of get_repr_for_variable * Revert "handle changed API of theano.gof.graph.stack_search" This reverts commit 6238bff.
1 parent 240c372 commit e783106

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

pymc3/distributions/multivariate.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -963,8 +963,8 @@ class _LKJCholeskyCov(Continuous):
963963
"""
964964

965965
def __init__(self, eta, n, sd_dist, *args, **kwargs):
966-
self.n = n
967-
self.eta = eta
966+
self.n = tt.as_tensor_variable(n)
967+
self.eta = tt.as_tensor_variable(eta)
968968

969969
if "transform" in kwargs and kwargs["transform"] is not None:
970970
raise ValueError("Invalid parameter: transform.")
@@ -1129,6 +1129,9 @@ def random(self, point=None, size=None):
11291129
samples = np.reshape(samples, size + sample_shape)
11301130
return samples
11311131

1132+
def _distr_parameters_for_repr(self):
1133+
return ["eta", "n"]
1134+
11321135

11331136
def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=True, *args, **kwargs):
11341137
R"""Wrapper function for covariance matrix with LKJ distributed correlations.

pymc3/util.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,13 @@ def get_default_varnames(var_iterator, include_transformed):
131131

132132
def get_repr_for_variable(variable, formatting="plain"):
133133
"""Build a human-readable string representation for a variable."""
134-
name = variable.name if variable is not None else None
134+
if variable is not None and hasattr(variable, "name"):
135+
name = variable.name
136+
elif type(variable) in [float, int, str]:
137+
name = str(variable)
138+
else:
139+
name = None
140+
135141
if name is None and variable is not None:
136142
if hasattr(variable, "get_parents"):
137143
try:

0 commit comments

Comments
 (0)