|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import numpy as np |
| 15 | +import pymc3 as pm |
14 | 16 |
|
15 |
| -from pymc3.memoize import memoize |
| 17 | +from pymc3 import memoize |
16 | 18 |
|
17 | 19 |
|
18 |
| -def getmemo(): |
19 |
| - @memoize |
20 |
| - def f(a, b=("a")): |
21 |
| - return str(a) + str(b) |
| 20 | +def test_memo(): |
| 21 | + def fun(inputs, suffix="_a"): |
| 22 | + return str(inputs) + str(suffix) |
22 | 23 |
|
23 |
| - return f |
| 24 | + inputs = ["i1", "i2"] |
| 25 | + assert fun(inputs) == "['i1', 'i2']_a" |
| 26 | + assert fun(inputs, "_b") == "['i1', 'i2']_b" |
24 | 27 |
|
| 28 | + funmem = memoize.memoize(fun) |
| 29 | + assert hasattr(fun, "cache") |
| 30 | + assert isinstance(fun.cache, dict) |
| 31 | + assert len(fun.cache) == 0 |
| 32 | + |
| 33 | + # call the memoized function with a list input |
| 34 | + # and check the size of the cache! |
| 35 | + assert funmem(inputs) == "['i1', 'i2']_a" |
| 36 | + assert funmem(inputs) == "['i1', 'i2']_a" |
| 37 | + assert len(fun.cache) == 1 |
| 38 | + assert funmem(inputs, "_b") == "['i1', 'i2']_b" |
| 39 | + assert funmem(inputs, "_b") == "['i1', 'i2']_b" |
| 40 | + assert len(fun.cache) == 2 |
| 41 | + |
| 42 | + # add items to the inputs list (the list instance remains identical !!) |
| 43 | + inputs.append("i3") |
| 44 | + assert funmem(inputs) == "['i1', 'i2', 'i3']_a" |
| 45 | + assert funmem(inputs) == "['i1', 'i2', 'i3']_a" |
| 46 | + assert len(fun.cache) == 3 |
25 | 47 |
|
26 |
| -def test_memo(): |
27 |
| - f = getmemo() |
28 | 48 |
|
29 |
| - assert f("x", ["y", "z"]) == "x['y', 'z']" |
30 |
| - assert f("x", ["a", "z"]) == "x['a', 'z']" |
31 |
| - assert f("x", ["y", "z"]) == "x['y', 'z']" |
| 49 | +def test_hashing_of_rv_tuples(): |
| 50 | + obs = np.random.normal(-1, 0.1, size=10) |
| 51 | + with pm.Model() as pmodel: |
| 52 | + mu = pm.Normal("mu", 0, 1) |
| 53 | + sd = pm.Gamma("sd", 1, 2) |
| 54 | + dd = pm.DensityDist( |
| 55 | + "dd", |
| 56 | + pm.Normal.dist(mu, sd).logp, |
| 57 | + random=pm.Normal.dist(mu, sd).random, |
| 58 | + observed=obs, |
| 59 | + ) |
| 60 | + for freerv in [mu, sd, dd] + pmodel.free_RVs: |
| 61 | + for structure in [ |
| 62 | + freerv, |
| 63 | + {"alpha": freerv, "omega": None}, |
| 64 | + [freerv, []], |
| 65 | + (freerv, []), |
| 66 | + ]: |
| 67 | + assert isinstance(memoize.hashable(structure), int) |
0 commit comments