Skip to content

Commit 66f078c

Browse files
Spaakcanyon289
andauthored
Fix issue with pickling Deterministic (#4120)
* Add eight schools pickle test * Add deterministic test * adding test for str() explicitly * Add end of line * fixes #4112, promote dynamic type to full class * removing unused import functools Co-authored-by: Ravin Kumar <[email protected]>
1 parent 3ae43cf commit 66f078c

File tree

3 files changed

+50
-18
lines changed

3 files changed

+50
-18
lines changed

pymc3/model.py

+17-18
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import collections
16-
import functools
1716
import itertools
1817
import threading
1918
import warnings
@@ -1903,14 +1902,22 @@ def _walk_up_rv(rv, formatting='plain'):
19031902
return all_rvs
19041903

19051904

1906-
def _repr_deterministic_rv(rv, formatting='plain'):
1907-
"""Make latex string for a Deterministic variable"""
1908-
if formatting == 'latex':
1909-
return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
1910-
name=rv.name, args=r",~".join(_walk_up_rv(rv, formatting=formatting)))
1911-
else:
1912-
return "{name} ~ Deterministic({args})".format(
1913-
name=rv.name, args=", ".join(_walk_up_rv(rv, formatting=formatting)))
1905+
class DeterministicWrapper(tt.TensorVariable):
1906+
def _str_repr(self, formatting='plain'):
1907+
if formatting == 'latex':
1908+
return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
1909+
name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting)))
1910+
else:
1911+
return "{name} ~ Deterministic({args})".format(
1912+
name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting)))
1913+
1914+
def _repr_latex_(self):
1915+
return self._str_repr(formatting='latex')
1916+
1917+
__latex__ = _repr_latex_
1918+
1919+
def __str__(self):
1920+
return self._str_repr(formatting='plain')
19141921

19151922

19161923
def Deterministic(name, var, model=None, dims=None):
@@ -1929,15 +1936,7 @@ def Deterministic(name, var, model=None, dims=None):
19291936
var = var.copy(model.name_for(name))
19301937
model.deterministics.append(var)
19311938
model.add_random_variable(var, dims)
1932-
var._repr_latex_ = functools.partial(_repr_deterministic_rv, var, formatting='latex')
1933-
var.__latex__ = var._repr_latex_
1934-
1935-
# simply assigning var.__str__ is not enough, since str() will default to the class-
1936-
# defined __str__ anyway; see https://stackoverflow.com/a/5918210/1692028
1937-
old_type = type(var)
1938-
new_type = type(old_type.__name__ + '_pymc3_Deterministic', (old_type,),
1939-
{'__str__': functools.partial(_repr_deterministic_rv, var, formatting='plain')})
1940-
var.__class__ = new_type
1939+
var.__class__ = DeterministicWrapper # adds str and latex functionality
19411940

19421941
return var
19431942

pymc3/tests/test_distributions.py

+8
Original file line numberDiff line numberDiff line change
@@ -1774,6 +1774,14 @@ def test___str__(self):
17741774
for str_repr in self.expected_str:
17751775
assert str_repr in model_str
17761776

1777+
def test_str(self):
1778+
for distribution, str_repr in zip(self.distributions, self.expected_str):
1779+
assert str(distribution) == str_repr
1780+
1781+
model_str = str(self.model)
1782+
for str_repr in self.expected_str:
1783+
assert str_repr in model_str
1784+
17771785

17781786
def test_discrete_trafo():
17791787
with pytest.raises(ValueError) as err:

pymc3/tests/test_model.py

+25
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import theano
1717
import theano.tensor as tt
1818
import numpy as np
19+
import pickle
1920
import pandas as pd
2021
import numpy.testing as npt
2122
import unittest
@@ -421,3 +422,27 @@ def test_tempered_logp_dlogp():
421422

422423
npt.assert_allclose(func_nograd(x), func(x)[0])
423424
npt.assert_allclose(func_temp_nograd(x), func_temp(x)[0])
425+
426+
427+
def test_model_pickle(tmpdir):
428+
"""Tests that PyMC3 models are pickleable"""
429+
with pm.Model() as model:
430+
x = pm.Normal('x')
431+
pm.Normal('y', observed=1)
432+
433+
file_path = tmpdir.join("model.p")
434+
with open(file_path, 'wb') as buff:
435+
pickle.dump(model, buff)
436+
437+
438+
def test_model_pickle_deterministic(tmpdir):
439+
"""Tests that PyMC3 models are pickleable"""
440+
with pm.Model() as model:
441+
x = pm.Normal('x')
442+
z = pm.Normal("z")
443+
pm.Deterministic("w", x/z)
444+
pm.Normal('y', observed=1)
445+
446+
file_path = tmpdir.join("model.p")
447+
with open(file_path, 'wb') as buff:
448+
pickle.dump(model, buff)

0 commit comments

Comments
 (0)