|
32 | 32 | import pymc3 as pm
|
33 | 33 |
|
34 | 34 | from pymc3.aesaraf import change_rv_size, floatX, intX
|
35 |
| -from pymc3.distributions.continuous import get_tau_sigma, interpolated |
| 35 | +from pymc3.distributions.continuous import ( |
| 36 | + _interpolated_argcdf, |
| 37 | + get_tau_sigma, |
| 38 | + interpolated, |
| 39 | +) |
36 | 40 | from pymc3.distributions.dist_math import clipped_beta_rvs
|
37 | 41 | from pymc3.distributions.multivariate import quaddist_matrix
|
38 | 42 | from pymc3.distributions.shape_utils import to_tuple
|
@@ -1242,7 +1246,28 @@ class TestOrderedProbit(BaseTestDistribution):
|
1242 | 1246 | ]
|
1243 | 1247 |
|
1244 | 1248 |
|
1245 |
| -class TestInterpolated(SeededTest): |
| 1249 | +class TestInterpolated(BaseTestDistribution): |
| 1250 | + def interpolated_rng_fn(self, size, mu, sigma, rng): |
| 1251 | + return st.norm.rvs(loc=mu, scale=sigma, size=size) |
| 1252 | + |
| 1253 | + pymc_dist = pm.Interpolated |
| 1254 | + |
| 1255 | + # Dummy values for RV size testing |
| 1256 | + mu = sigma = 1 |
| 1257 | + x_points = pdf_points = np.linspace(1, 100, 100) |
| 1258 | + |
| 1259 | + pymc_dist_params = {"x_points": x_points, "pdf_points": pdf_points} |
| 1260 | + reference_dist_params = {"mu": mu, "sigma": sigma} |
| 1261 | + |
| 1262 | + reference_dist = lambda self: functools.partial( |
| 1263 | + self.interpolated_rng_fn, rng=self.get_random_state() |
| 1264 | + ) |
| 1265 | + tests_to_run = [ |
| 1266 | + "check_rv_size", |
| 1267 | + ] |
| 1268 | + |
| 1269 | + |
| 1270 | +class TestInterpolatedSeeded(SeededTest): |
1246 | 1271 | @pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32")
|
1247 | 1272 | def test_interpolated(self):
|
1248 | 1273 | for mu in R.vals:
|
|
0 commit comments