Skip to content

Commit de47253

Browse files
test on windows and mac (#4269)
* run arviz compat on mac and windows * run test_distributions_random.py too on mac * update env file in windows, activate correct env * rename dataset_to_point_dict to dataset_to_point_list with DeprecationWarning * fix underlying bug behind DensityDist stack overflow tl;dr: The hashable helper function did not appropriately deal with tuples (and the test case did not actually test the memoization). In the process of prior-predictive sampling a model involving a DensityDist, the _compile_theano_function function was called with arguments (sd__log_, []). The _compile_theano_function has a pm.memo.memoize-decorator, which relies on the pm.memo.hashable for hashing of typically unhashable objects. The "hashable" function incorrectly handled tuples, eventually causing a stack overflow error on Windows. * rewrite dict as literal, use list comprehension * punctuation, naming Co-authored-by: Michael Osthege <[email protected]>
1 parent 580a32a commit de47253

File tree

9 files changed

+137
-29
lines changed

9 files changed

+137
-29
lines changed

.github/workflows/arviz_compat.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ jobs:
99
pytest:
1010
strategy:
1111
matrix:
12-
os: [ubuntu-18.04]
12+
os: [ubuntu-latest, macos-latest]
1313
floatx: [float64]
1414
test-subset:
15+
- pymc3/tests/test_distributions_random.py
1516
- pymc3/tests/test_sampling.py
1617
fail-fast: false
1718
runs-on: ${{ matrix.os }}

.github/workflows/windows.yml

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
name: windows
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches: [master]
7+
8+
jobs:
9+
pytest:
10+
strategy:
11+
matrix:
12+
os: [windows-latest]
13+
floatx: [float64]
14+
test-subset:
15+
- pymc3/tests/test_distributions_random.py
16+
- pymc3/tests/test_sampling.py
17+
runs-on: ${{ matrix.os }}
18+
env:
19+
TEST_SUBSET: ${{ matrix.test-subset }}
20+
THEANO_FLAGS: floatX=${{ matrix.floatx }},gcc.cxxflags='-march=core2'
21+
defaults:
22+
run:
23+
shell: bash -l {0}
24+
steps:
25+
- uses: actions/checkout@v2
26+
- name: Cache conda
27+
uses: actions/cache@v1
28+
env:
29+
# Increase this value to reset cache if conda-envs/environment-dev-py37.yml has not changed
30+
CACHE_NUMBER: 0
31+
with:
32+
path: ~/conda_pkgs_dir
33+
key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
34+
hashFiles('conda-envs/environment-dev-py37.yml') }}
35+
- uses: conda-incubator/setup-miniconda@v2
36+
with:
37+
activate-environment: pymc3-dev-py37
38+
channel-priority: strict
39+
environment-file: conda-envs/environment-dev-py37.yml
40+
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
41+
- run: |
42+
conda activate pymc3-dev-py37
43+
python -m pytest -vv --cov=pymc3 --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET

pymc3/distributions/posterior_predictive.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
)
4343
from ..exceptions import IncorrectArgumentsError
4444
from ..vartypes import theano_constant
45-
from ..util import dataset_to_point_dict, chains_and_samples, get_var_name
45+
from ..util import dataset_to_point_list, chains_and_samples, get_var_name
4646

4747
# Failing tests:
4848
# test_mixture_random_shape::test_mixture_random_shape
@@ -209,10 +209,10 @@ def fast_sample_posterior_predictive(
209209

210210
if isinstance(trace, InferenceData):
211211
nchains, ndraws = chains_and_samples(trace)
212-
trace = dataset_to_point_dict(trace.posterior)
212+
trace = dataset_to_point_list(trace.posterior)
213213
elif isinstance(trace, Dataset):
214214
nchains, ndraws = chains_and_samples(trace)
215-
trace = dataset_to_point_dict(trace)
215+
trace = dataset_to_point_list(trace)
216216
elif isinstance(trace, MultiTrace):
217217
nchains = trace.nchains
218218
ndraws = len(trace)

pymc3/memoize.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import functools
16-
import pickle
16+
import dill
1717
import collections
1818
from .util import biwrap
1919

@@ -23,7 +23,16 @@
2323
@biwrap
2424
def memoize(obj, bound=False):
2525
"""
26-
An expensive memoizer that works with unhashables
26+
Decorator to apply memoization to expensive functions.
27+
It uses a custom `hashable` helper function to hash typically unhashable Python objects.
28+
29+
Parameters
30+
----------
31+
obj : callable
32+
the function to apply the caching to
33+
bound : bool
34+
indicates if the [obj] is a bound method (self as first argument)
35+
For bound methods, the cache is kept in a `_cache` attribute on [self].
2736
"""
2837
# this is declared not to be a bound method, so just attach new attr to obj
2938
if not bound:
@@ -40,7 +49,7 @@ def memoizer(*args, **kwargs):
4049
key = (hashable(args[1:]), hashable(kwargs))
4150
if not hasattr(args[0], "_cache"):
4251
setattr(args[0], "_cache", collections.defaultdict(dict))
43-
# do not add to cache regestry
52+
# do not add to cache registry
4453
cache = getattr(args[0], "_cache")[obj.__name__]
4554
if key not in cache:
4655
cache[key] = obj(*args, **kwargs)
@@ -75,19 +84,26 @@ def __setstate__(self, state):
7584
self.__dict__.update(state)
7685

7786

78-
def hashable(a):
87+
def hashable(a) -> int:
7988
"""
80-
Turn some unhashable objects into hashable ones.
89+
Hashes many kinds of objects, including some that are unhashable through the builtin `hash` function.
90+
Lists and tuples are hashed based on their elements.
8191
"""
8292
if isinstance(a, dict):
83-
return hashable(tuple((hashable(a1), hashable(a2)) for a1, a2 in a.items()))
93+
# first hash the keys and values with hashable
94+
# then hash the tuple of int-tuples with the builtin
95+
return hash(tuple((hashable(k), hashable(v)) for k, v in a.items()))
96+
if isinstance(a, (tuple, list)):
97+
# lists are mutable and not hashable by default
98+
# for memoization, we need the hash to depend on the items
99+
return hash(tuple(hashable(i) for i in a))
84100
try:
85101
return hash(a)
86102
except TypeError:
87103
pass
88104
# Not hashable >>>
89105
try:
90-
return hash(pickle.dumps(a))
106+
return hash(dill.dumps(a))
91107
except Exception:
92108
if hasattr(a, "__dict__"):
93109
return hashable(a.__dict__)

pymc3/sampling.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
get_untransformed_name,
5757
is_transformed_name,
5858
get_default_varnames,
59-
dataset_to_point_dict,
59+
dataset_to_point_list,
6060
chains_and_samples,
6161
)
6262
from .vartypes import discrete_types
@@ -1648,9 +1648,9 @@ def sample_posterior_predictive(
16481648

16491649
_trace: Union[MultiTrace, PointList]
16501650
if isinstance(trace, InferenceData):
1651-
_trace = dataset_to_point_dict(trace.posterior)
1651+
_trace = dataset_to_point_list(trace.posterior)
16521652
elif isinstance(trace, xarray.Dataset):
1653-
_trace = dataset_to_point_dict(trace)
1653+
_trace = dataset_to_point_list(trace)
16541654
else:
16551655
_trace = trace
16561656

@@ -1786,10 +1786,10 @@ def sample_posterior_predictive_w(
17861786
n_samples = [
17871787
trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces
17881788
]
1789-
traces = [dataset_to_point_dict(trace.posterior) for trace in traces]
1789+
traces = [dataset_to_point_list(trace.posterior) for trace in traces]
17901790
elif isinstance(traces[0], xarray.Dataset):
17911791
n_samples = [trace.sizes["chain"] * trace.sizes["draw"] for trace in traces]
1792-
traces = [dataset_to_point_dict(trace) for trace in traces]
1792+
traces = [dataset_to_point_list(trace) for trace in traces]
17931793
else:
17941794
n_samples = [len(i) * i.nchains for i in traces]
17951795

pymc3/tests/test_distributions_random.py

+5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from scipy import linalg
2222
import numpy.random as nr
2323
import theano
24+
import sys
2425

2526
import pymc3 as pm
2627
from pymc3.distributions.dist_math import clipped_beta_rvs
@@ -713,6 +714,10 @@ def test_half_flat(self):
713714
def test_binomial(self):
714715
pymc3_random_discrete(pm.Binomial, {"n": Nat, "p": Unit}, ref_rand=st.binom.rvs)
715716

717+
@pytest.mark.xfail(
718+
sys.platform.startswith("win"),
719+
reason="Known issue: https://github.com/pymc-devs/pymc3/pull/4269",
720+
)
716721
def test_beta_binomial(self):
717722
pymc3_random_discrete(
718723
pm.BetaBinomial, {"n": Nat, "alpha": Rplus, "beta": Rplus}, ref_rand=self._beta_bin

pymc3/tests/test_memo.py

+47-11
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,57 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import numpy as np
15+
import pymc3 as pm
1416

15-
from pymc3.memoize import memoize
17+
from pymc3 import memoize
1618

1719

18-
def getmemo():
19-
@memoize
20-
def f(a, b=("a")):
21-
return str(a) + str(b)
20+
def test_memo():
21+
def fun(inputs, suffix="_a"):
22+
return str(inputs) + str(suffix)
2223

23-
return f
24+
inputs = ["i1", "i2"]
25+
assert fun(inputs) == "['i1', 'i2']_a"
26+
assert fun(inputs, "_b") == "['i1', 'i2']_b"
2427

28+
funmem = memoize.memoize(fun)
29+
assert hasattr(fun, "cache")
30+
assert isinstance(fun.cache, dict)
31+
assert len(fun.cache) == 0
32+
33+
# call the memoized function with a list input
34+
# and check the size of the cache!
35+
assert funmem(inputs) == "['i1', 'i2']_a"
36+
assert funmem(inputs) == "['i1', 'i2']_a"
37+
assert len(fun.cache) == 1
38+
assert funmem(inputs, "_b") == "['i1', 'i2']_b"
39+
assert funmem(inputs, "_b") == "['i1', 'i2']_b"
40+
assert len(fun.cache) == 2
41+
42+
# add items to the inputs list (the list instance remains identical !!)
43+
inputs.append("i3")
44+
assert funmem(inputs) == "['i1', 'i2', 'i3']_a"
45+
assert funmem(inputs) == "['i1', 'i2', 'i3']_a"
46+
assert len(fun.cache) == 3
2547

26-
def test_memo():
27-
f = getmemo()
2848

29-
assert f("x", ["y", "z"]) == "x['y', 'z']"
30-
assert f("x", ["a", "z"]) == "x['a', 'z']"
31-
assert f("x", ["y", "z"]) == "x['y', 'z']"
49+
def test_hashing_of_rv_tuples():
50+
obs = np.random.normal(-1, 0.1, size=10)
51+
with pm.Model() as pmodel:
52+
mu = pm.Normal("mu", 0, 1)
53+
sd = pm.Gamma("sd", 1, 2)
54+
dd = pm.DensityDist(
55+
"dd",
56+
pm.Normal.dist(mu, sd).logp,
57+
random=pm.Normal.dist(mu, sd).random,
58+
observed=obs,
59+
)
60+
for freerv in [mu, sd, dd] + pmodel.free_RVs:
61+
for structure in [
62+
freerv,
63+
{"alpha": freerv, "omega": None},
64+
[freerv, []],
65+
(freerv, []),
66+
]:
67+
assert isinstance(memoize.hashable(structure), int)

pymc3/tests/test_sampling.py

-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import pytest
3232

3333

34-
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
3534
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
3635
class TestSample(SeededTest):
3736
def setup_method(self):
@@ -953,7 +952,6 @@ def test_shared(self):
953952
assert gen2["y"].shape == (draws, n2)
954953

955954
def test_density_dist(self):
956-
957955
obs = np.random.normal(-1, 0.1, size=10)
958956
with pm.Model():
959957
mu = pm.Normal("mu", 0, 1)

pymc3/util.py

+9
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import re
1616
import functools
1717
from typing import List, Dict, Tuple, Union
18+
import warnings
1819

1920
import numpy as np
2021
import xarray
@@ -258,6 +259,14 @@ def enhanced(*args, **kwargs):
258259
# FIXME: this function is poorly named, because it returns a LIST of
259260
# points, not a dictionary of points.
260261
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
262+
warnings.warn(
263+
"dataset_to_point_dict was renamed to dataset_to_point_list and will be removed!",
264+
DeprecationWarning,
265+
)
266+
return dataset_to_point_list(ds)
267+
268+
269+
def dataset_to_point_list(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
261270
# grab posterior samples for each variable
262271
_samples: Dict[str, np.ndarray] = {vn: ds[vn].values for vn in ds.keys()}
263272
# make dicts

0 commit comments

Comments
 (0)