Skip to content

Commit 9169718

Browse files
authored
Merge pull request #330 from pymc-labs/failing-doctests
Fix failing doctests
2 parents 15b5756 + 8805fb2 commit 9169718

File tree

2 files changed

+62
-12
lines changed

2 files changed

+62
-12
lines changed

causalpy/pymc_experiments.py

+36
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def print_coefficients(self, round_to=None) -> None:
9999
... "progressbar": False
100100
... }),
101101
... )
102+
<BLANKLINE>
103+
<BLANKLINE>
104+
<BLANKLINE>
105+
<BLANKLINE>
106+
<BLANKLINE>
102107
>>> result.print_coefficients(round_to=1) # doctest: +NUMBER
103108
Model coefficients:
104109
Intercept 1, 94% HDI [1, 1]
@@ -157,6 +162,11 @@ class PrePostFit(ExperimentalDesign, PrePostFitDataValidator):
157162
... }
158163
... ),
159164
... )
165+
<BLANKLINE>
166+
<BLANKLINE>
167+
<BLANKLINE>
168+
<BLANKLINE>
169+
<BLANKLINE>
160170
>>> result.summary(round_to=1) # doctest: +NUMBER
161171
==================================Pre-Post Fit==================================
162172
Formula: actual ~ 0 + a + g
@@ -383,6 +393,11 @@ class InterruptedTimeSeries(PrePostFit):
383393
... }
384394
... )
385395
... )
396+
<BLANKLINE>
397+
<BLANKLINE>
398+
<BLANKLINE>
399+
<BLANKLINE>
400+
<BLANKLINE>
386401
"""
387402

388403
expt_type = "Interrupted Time Series"
@@ -418,6 +433,11 @@ class SyntheticControl(PrePostFit):
418433
... }
419434
... ),
420435
... )
436+
<BLANKLINE>
437+
<BLANKLINE>
438+
<BLANKLINE>
439+
<BLANKLINE>
440+
<BLANKLINE>
421441
"""
422442

423443
expt_type = "Synthetic Control"
@@ -474,6 +494,11 @@ class DifferenceInDifferences(ExperimentalDesign, DiDDataValidator):
474494
... }
475495
... )
476496
... )
497+
<BLANKLINE>
498+
<BLANKLINE>
499+
<BLANKLINE>
500+
<BLANKLINE>
501+
<BLANKLINE>
477502
"""
478503

479504
def __init__(
@@ -765,6 +790,11 @@ class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator):
765790
... ),
766791
... treatment_threshold=0.5,
767792
... )
793+
<BLANKLINE>
794+
<BLANKLINE>
795+
<BLANKLINE>
796+
<BLANKLINE>
797+
<BLANKLINE>
768798
"""
769799

770800
def __init__(
@@ -1174,6 +1204,10 @@ class PrePostNEGD(ExperimentalDesign, PrePostNEGDDataValidator):
11741204
... }
11751205
... )
11761206
... )
1207+
<BLANKLINE>
1208+
<BLANKLINE>
1209+
<BLANKLINE>
1210+
<BLANKLINE>
11771211
>>> result.summary(round_to=1) # doctest: +NUMBER
11781212
==================Pretest/posttest Nonequivalent Group Design===================
11791213
Formula: post ~ 1 + C(group) + pre
@@ -1404,6 +1438,8 @@ class InstrumentalVariable(ExperimentalDesign, IVDataValidator):
14041438
... formula=formula,
14051439
... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
14061440
... )
1441+
<BLANKLINE>
1442+
<BLANKLINE>
14071443
"""
14081444

14091445
def __init__(

causalpy/pymc_models.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,17 @@ class ModelBuilder(pm.Model):
7272
... }
7373
... )
7474
>>> model.fit(X, y)
75-
Inference...
75+
<BLANKLINE>
76+
<BLANKLINE>
77+
Inference data...
7678
>>> X_new = rng.normal(loc=0, scale=1, size=(20,2))
7779
>>> model.predict(X_new)
78-
Inference...
79-
>>> model.score(X, y) # doctest: +NUMBER
80-
r2 0.3
81-
r2_std 0.0
80+
<BLANKLINE>
81+
Inference data...
82+
>>> model.score(X, y)
83+
<BLANKLINE>
84+
r2 0.390344
85+
r2_std 0.081135
8286
dtype: float64
8387
"""
8488

@@ -112,10 +116,7 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
112116

113117
# Ensure random_seed is used in sample_prior_predictive() and
114118
# sample_posterior_predictive() if provided in sample_kwargs.
115-
if "random_seed" in self.sample_kwargs:
116-
random_seed = self.sample_kwargs["random_seed"]
117-
else:
118-
random_seed = None
119+
random_seed = self.sample_kwargs.get("random_seed", None)
119120

120121
self.build_model(X, y, coords)
121122
with self:
@@ -137,10 +138,17 @@ def predict(self, X):
137138
138139
"""
139140

141+
# Ensure random_seed is used in sample_prior_predictive() and
142+
# sample_posterior_predictive() if provided in sample_kwargs.
143+
random_seed = self.sample_kwargs.get("random_seed", None)
144+
140145
self._data_setter(X)
141146
with self: # sample with new input data
142147
post_pred = pm.sample_posterior_predictive(
143-
self.idata, var_names=["y_hat", "mu"], progressbar=False
148+
self.idata,
149+
var_names=["y_hat", "mu"],
150+
progressbar=False,
151+
random_seed=random_seed,
144152
)
145153
return post_pred
146154

@@ -193,7 +201,9 @@ class WeightedSumFitter(ModelBuilder):
193201
>>> y = np.asarray(sc['actual']).reshape((sc.shape[0], 1))
194202
>>> wsf = WeightedSumFitter(sample_kwargs={"progressbar": False})
195203
>>> wsf.fit(X,y)
196-
Inference ...
204+
<BLANKLINE>
205+
<BLANKLINE>
206+
Inference data...
197207
""" # noqa: W605
198208

199209
def build_model(self, X, y, coords):
@@ -249,7 +259,9 @@ class LinearRegression(ModelBuilder):
249259
... 'obs_indx': np.arange(rd.shape[0])
250260
... },
251261
... )
252-
Inference...
262+
<BLANKLINE>
263+
<BLANKLINE>
264+
Inference data...
253265
""" # noqa: W605
254266

255267
def build_model(self, X, y, coords):
@@ -301,6 +313,8 @@ class InstrumentalVariableRegression(ModelBuilder):
301313
... "eta": 2,
302314
... "lkj_sd": 2,
303315
... })
316+
<BLANKLINE>
317+
<BLANKLINE>
304318
Inference data...
305319
"""
306320

0 commit comments

Comments
 (0)