Skip to content

Commit 6ced6a9

Browse files
authored
Merge pull request #337 from NathanielF/add_dag_fix_bug
add Propensity Score DAG as Quasi-Experiment writeup
2 parents 0ffe348 + 6ecb992 commit 6ced6a9

File tree

4 files changed

+159
-8
lines changed

4 files changed

+159
-8
lines changed

causalpy/pymc_models.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,11 +453,17 @@ def fit(self, X, t, coords):
453453
distributions. We overwrite the base method because the base method assumes
454454
a variable y and we use t to indicate the treatment variable here.
455455
"""
456+
# Ensure random_seed is used in sample_prior_predictive() and
457+
# sample_posterior_predictive() if provided in sample_kwargs.
458+
random_seed = self.sample_kwargs.get("random_seed", None)
459+
456460
self.build_model(X, t, coords)
457461
with self:
458462
self.idata = pm.sample(**self.sample_kwargs)
459-
self.idata.extend(pm.sample_prior_predictive())
463+
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
460464
self.idata.extend(
461-
pm.sample_posterior_predictive(self.idata, progressbar=False)
465+
pm.sample_posterior_predictive(
466+
self.idata, progressbar=False, random_seed=random_seed
467+
)
462468
)
463469
return self.idata

causalpy/tests/test_pymc_experiments.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,13 @@ def test_regression_kink_gradient_change():
6363

6464
def test_inverse_prop():
6565
df = cp.load_data("nhefs")
66-
sample_kwargs = {"tune": 100, "draws": 100, "chains": 2, "cores": 2}
66+
sample_kwargs = {
67+
"tune": 100,
68+
"draws": 500,
69+
"chains": 2,
70+
"cores": 2,
71+
"random_seed": 100,
72+
}
6773
result = cp.pymc_experiments.InversePropensityWeighting(
6874
df,
6975
formula="trt ~ 1 + age + race",
@@ -93,7 +99,7 @@ def test_inverse_prop():
9399
assert isinstance(ate_list, list)
94100
ate_list = result.get_ate(0, result.idata, method="overlap")
95101
assert isinstance(ate_list, list)
96-
fig = result.plot_ATE(prop_draws=10, ate_draws=10)
102+
fig = result.plot_ATE(prop_draws=1, ate_draws=10)
97103
assert isinstance(fig, mpl.figure.Figure)
98104
fig = result.plot_balance_ecdf("age")
99105
assert isinstance(fig, mpl.figure.Figure)
Lines changed: 58 additions & 0 deletions
Loading

docs/source/quasi_dags.ipynb

Lines changed: 85 additions & 4 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)