(lecture_13)=
:::{post} Jan 7, 2024
:tags: statistical rethinking, bayesian inference, multilevel models
:category: intermediate
:author: Dustin Stansbury
:::
This notebook is part of the PyMC port of the Statistical Rethinking 2023 lecture series by Richard McElreath.
Video - Lecture 13 - Multilevel Adventures
# Ignore warnings
import warnings
import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
import statsmodels.formula.api as smf
import utils as utils
import xarray as xr
from matplotlib import pyplot as plt
from matplotlib import style
from scipy import stats as stats
warnings.filterwarnings("ignore")
# Set matplotlib style
STYLE = "statistical-rethinking-2023.mplstyle"
style.use(STYLE)
1. Establish the estimand
2. Build scientific model(s) (i.e. causal models), depending on 1
3. Use 1 & 2 to build a statistical model
4. Simulate data from 2 and validate that you can recover the simulated values using 3
5. Analyze real data using 3
In real life, it's never a linear path; you are jumping back and forth, iterating on 2-5, much like a branching path/choose your own adventure book.
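The simulate-and-recover step can be sketched in miniature. This is only an illustration under assumed settings: a single Bernoulli outcome, a flat prior, and a grid-approximate posterior. `TRUE_P`, `N_OBS`, and the grid resolution are hypothetical choices, not values from the lecture.

```python
import math
import random

random.seed(1)

# Hypothetical "true" parameter of the generative model (an assumption for this sketch)
TRUE_P = 0.3
N_OBS = 500

# Simulate data from the generative model
data = [1 if random.random() < TRUE_P else 0 for _ in range(N_OBS)]
n_success = sum(data)

# Grid-approximate the posterior of p under a flat prior, working on the log scale
grid = [i / 200 for i in range(1, 200)]
log_post = [
    n_success * math.log(p) + (N_OBS - n_success) * math.log(1 - p) for p in grid
]
max_lp = max(log_post)
weights = [math.exp(lp - max_lp) for lp in log_post]
total = sum(weights)
posterior = [w / total for w in weights]

# Validate: the posterior mean should land close to the value we simulated from
posterior_mean = sum(p * w for p, w in zip(grid, posterior))
print(f"true p = {TRUE_P}, posterior mean = {posterior_mean:.3f}")
```

If the estimator cannot recover a known parameter from simulated data, there is little reason to trust it on real data.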
Similarly there is no one-size-fits all approach to applying the methods in this course. In order to optimize for success when applying these methods, McElreath suggests a few strategies ("paths") moving forward:
- Return to the start -- McElreath suggests returning to the beginning of the course and reviewing the material now that you've seen the lion's share of the building blocks.
  - It turns out that the notes presented in this repo are from after the third pass of the course. I can't recommend strongly enough taking McElreath's advice and reviewing the material from the beginning. It's a lot of material, but I was flabbergasted by how much I had forgotten in the short time between this lecture and the earlier ones. Similarly, I was surprised by how much easier it was to soak up the material the second time around -- a real testament to McElreath's outstanding teaching style.
- Skim & Index -- don't sweat the details, just acquaint yourself with the possibilities.
  - One thing that I've found useful is to compile each of the model classes discussed in the course into a "recipe book" or "toolbox" of plug-and-play models that can be reused for different applications.
- Clusters: subgroups in the data (e.g. tanks, participants, stories, departments)
  - Adding clusters is fairly simple
    - requires more index variables; more population priors
- Features: aspects of the model (i.e. parameters) that vary by cluster (e.g. survival, average response rate, admission rate, etc.)
  - Adding features requires more complexity
    - more parameters, particularly dimensions in each population prior
- Varying effect strategy: using repeat observations and partial pooling to estimate unmeasured features of clusters that have left an imprint on the data
  - Predictive perspective: regularization
  - Causal perspective: unobserved confounds are terrifying, but leveraging repeat observations gives us some hope of more accurate inference
utils.draw_causal_graph(
edge_list=[("G", "P"), ("G", "C"), ("P", "C"), ("U", "P"), ("U", "C")],
node_props={
"G": {"label": "Grandparents Ed, G"},
"P": {"label": "Parents Ed, P"},
"C": {"label": "Children's Ed, C"},
"U": {"label": "Neighborhood, U", "color": "blue"},
},
edge_props={("U", "P"): {"color": "blue"}, ("U", "C"): {"color": "blue"}},
)
- Neighborhood is a backdoor path confound that blocks mediation analysis of the direct effect of $G$ on $C$
- but having repeat observations for neighborhoods $U$ allows us to estimate the effects of this confound
+++
utils.draw_causal_graph(
edge_list=[
("X", "R"),
("S", "R"),
("E", "R"),
("G", "R"),
("Y", "R"),
("U", "R"),
("G", "E"),
("Y", "E"),
("E", "P"),
("U", "P"),
],
node_props={
"X": {"label": "Treatment, X"},
"R": {"label": "Response, R"},
"P": {"label": "Participation, P", "style": "dashed"},
"U": {"label": "Individual Traits, U", "color": "blue"},
"unobserved": {"style": "dashed"},
},
edge_props={("U", "P"): {"color": "blue"}, ("U", "R"): {"color": "blue"}},
)
- Individuals vary in how they react to the response scale, adding noise to our estimates
- However, given that each participant has repeat observations, we use the repeats to estimate this noise.
- Similarly, individual traits may cause sampling bias through an unobserved participation node; we can use mixed effects to help address this sampling bias.
+++
- Fixed effects: no pooling, rather than partial pooling.
- There are very few benefits to using fixed effects over varying effects.
  - e.g. fixed effects are less efficient
- Focus on getting the story straight (generative model, causal graph); you can worry about the details of estimator efficiency, etc. later
- Varying effects models are always a good default, but they raise practical questions:
  - How do you use more than one cluster?
  - Predictions now exist at multiple levels of the hierarchy; which level do we care about?
  - Sampling efficiency -- e.g. centered/non-centered priors
  - Group-level confounding -- e.g. Full Luxury Bayes or Mundlak Machines. For details, see the BONUS section of Lecture 12 - Multilevel Models
+++
- 1989 Fertility Survey
- 1924 women, 61 districts
- Outcome, $C$: contraceptive use (binary variable)
- Predictors: age, $A$; # of living children, $K$; urban/rural location, $U$
- Potential (unobserved) confounds: family traits, $F$
- District ID: $D$
FERTILITY = utils.load_data("bangladesh")
FERTILITY.head()
utils.draw_causal_graph(
edge_list=[
("A", "C"),
("K", "C"),
("U", "C"),
("D", "C"),
("D", "U"),
("U", "K"),
("A", "K"),
("D", "K"),
("F", "C"),
("F", "K"),
],
node_props={
"A": {"label": "age, A"},
"K": {"label": "# kids, K"},
"U": {"label": "urbanity, U"},
"D": {"label": "district, D"},
"C": {"label": "contraceptive use, C"},
"F": {"label": "family traits, F", "style": "dashed"},
"unobserved": {"style": "dashed"},
},
)
district_counts = FERTILITY.groupby("district").count()["woman"]
plt.bar(district_counts.index, district_counts)
plt.xlabel("district")
plt.ylabel("# of women")
plt.title("Variation in district-level sampling");
- Estimand: contraceptive use in each district; partially pooled
- Model:
  - varying intercept/offset for each district
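Written out, the varying-intercept model takes the following form. The priors are read off the PyMC code in this section, and the notation ($D[i]$ for the district of woman $i$) is an assumption of this sketch:

$$
\begin{align*}
C_i &\sim \text{Bernoulli}(p_i) \\
\text{logit}(p_i) &= \alpha_{D[i]} \\
\alpha_j &\sim \text{Normal}(\bar\alpha, \sigma) \\
\bar\alpha &\sim \text{Normal}(0, 1) \\
\sigma &\sim \text{Exponential}(1)
\end{align*}
$$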
utils.draw_causal_graph(
edge_list=[
("D", "C"),
],
node_props={
"D": {"label": "district, D"},
"C": {"label": "contraceptive use, C"},
},
)
USES_CONTRACEPTION = FERTILITY["use.contraception"].values.astype(int)
# pd.factorize gives contiguous integer codes for the districts present in the data
DISTRICT_ID, _ = pd.factorize(FERTILITY.district)
DISTRICT = np.arange(1, 62).astype(
    int
)  # note: district 54 has no data so we create its dim by hand

with pm.Model(coords={"district": DISTRICT}) as district_model:
    # Priors
    ## Global priors
    sigma = pm.Exponential("sigma", 1)  # variation amongst districts
    alpha_bar = pm.Normal("alpha_bar", 0, 1)  # the average district

    # District-level priors
    alpha = pm.Normal("alpha", alpha_bar, sigma, dims="district")

    # p(contraceptive)
    p_C = pm.Deterministic("p_C", pm.math.invlogit(alpha), dims="district")

    # Likelihood
    p = pm.math.invlogit(alpha[DISTRICT_ID])
    C = pm.Bernoulli("C", p=p, observed=USES_CONTRACEPTION)

    district_inference = pm.sample()
def plot_survival_posterior(
    inference,
    sigma=None,
    color="C0",
    var="p_C",
    hdi_prob=0.89,
    data_filter=None,
    title=None,
    ax=None,
):
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))

    def reorder_missing_district_param(vec):
        """
        `pd.factorize` assigns contiguous codes to the districts that appear in
        the data (assuming rows are ordered by district), so the slot for the
        district with no data (54) ends up at the end of the parameter vector.
        This closure moves it into its correct position (index 53).
        """
        vec_ = vec.copy()
        end = vec_[-1]  # no data is district 54 (index 53)
        vec_ = np.delete(vec_, -1)
        vec_ = np.insert(vec_, 53, end)  # insert before index 53, i.e. at district 54
        return vec_

    # Filter the dataset for urban/rural if requested
    if data_filter == "urban":
        data_mask = (FERTILITY.urban).astype(bool)
    elif data_filter == "rural":
        data_mask = (1 - FERTILITY.urban).astype(bool)
    else:
        data_mask = np.ones(len(FERTILITY)).astype(bool)

    plot_data = FERTILITY[data_mask]
    district_counts = plot_data.groupby("district").count()["woman"]
    contraceptive_counts = plot_data.groupby("district").sum()["use.contraception"]
    proportion_contraceptive = contraceptive_counts / district_counts

    plt.sca(ax)
    utils.plot_scatter(
        xs=proportion_contraceptive.index,
        ys=proportion_contraceptive.values,
        color="k",
        s=50,
        zorder=3,
        alpha=0.8,
        label="raw proportions",
    )

    # Posterior per-district mean probability of contraceptive use
    posterior_mean = inference.posterior.mean(dim=("chain", "draw"))[var]
    posterior_mean = reorder_missing_district_param(posterior_mean.values)
    utils.plot_scatter(
        DISTRICT, posterior_mean, color=color, zorder=50, alpha=0.8, label="posterior means"
    )

    # Posterior HDI error bars
    hdis = az.hdi(inference.posterior, var_names=var, hdi_prob=hdi_prob)[var].values
    error_upper = reorder_missing_district_param(hdis[:, 1]) - posterior_mean
    error_lower = posterior_mean - reorder_missing_district_param(hdis[:, 0])
    utils.plot_errorbar(
        xs=DISTRICT,
        ys=posterior_mean,
        error_lower=error_lower,
        error_upper=error_upper,
        colors=color,
        error_width=8,
    )

    # Add empirical mean
    empirical_mean = FERTILITY[data_mask]["use.contraception"].mean()
    plt.axhline(y=empirical_mean, c="k", linestyle="--", label="global mean")

    plt.ylim([-0.05, 1.05])
    plt.xlabel("district")
    plt.ylabel("prob. use contraception")
    plt.title(title)
    plt.legend();
plot_survival_posterior(district_inference, title="District Model")
Districts with small sample sizes (e.g. district 3) have
- larger error bars -- more uncertainty in the estimates
- more shrinkage
  - posteriors are pulled toward the global mean (dashed line)
  - posterior means are far from the raw proportions (black circles) because the model is less confident
Districts with large sample sizes (e.g. district 1) have
- smaller error bars -- more certainty about estimates
- less shrinkage
- posteriors are closer to the empirical observations for the district
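The dependence of shrinkage on sample size can be illustrated with the closed-form partial-pooling estimate for a normal-normal model. This is only a sketch under simplifying assumptions: the notebook's model is logistic and fit by MCMC, and `grand_mean`, `sigma_between`, and `sigma_within` are hypothetical values chosen for illustration.

```python
def partial_pool(cluster_mean, n, grand_mean, sigma_between, sigma_within):
    """Precision-weighted compromise between a cluster's own mean and the grand mean."""
    data_precision = n / sigma_within**2
    prior_precision = 1 / sigma_between**2
    weight = data_precision / (data_precision + prior_precision)
    return weight * cluster_mean + (1 - weight) * grand_mean


# Hypothetical population settings (assumptions, not estimates from the fertility data)
settings = dict(grand_mean=0.4, sigma_between=0.1, sigma_within=0.5)

# Same raw cluster mean, very different sample sizes
small = partial_pool(0.9, n=2, **settings)
large = partial_pool(0.9, n=100, **settings)
empty = partial_pool(0.9, n=0, **settings)  # no data: the estimate is the grand mean

print(f"n=2: {small:.3f}, n=100: {large:.3f}, n=0: {empty:.3f}")
```

The cluster with only two observations is pulled most of the way to the grand mean, while the cluster with 100 observations stays near its raw mean.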
Districts with no data (e.g. district 54) still have posteriors
- informed posterior from partial pooling
- mean is near global mean
- error bars larger than for other districts (it looks like I may have an indexing bug in my errorbar code -- need to look into that)
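A district with no observations still gets a posterior because its intercept is drawn from the population distribution. A minimal sketch of that idea, assuming made-up posterior draws for `alpha_bar` and `sigma` (the numbers below are hypothetical, not taken from the fitted model):

```python
import math
import random
import statistics

random.seed(0)


def invlogit(x):
    return 1 / (1 + math.exp(-x))


# Hypothetical posterior draws for the population-level parameters (assumed values)
alpha_bar_draws = [random.gauss(-0.5, 0.1) for _ in range(4000)]
sigma_draws = [abs(random.gauss(0.5, 0.05)) for _ in range(4000)]

# For an unobserved district, each draw of its intercept comes from the population prior
alpha_new = [random.gauss(a, s) for a, s in zip(alpha_bar_draws, sigma_draws)]
p_new = [invlogit(a) for a in alpha_new]

p_new_mean = statistics.mean(p_new)
p_average_district = invlogit(statistics.mean(alpha_bar_draws))
print(f"unobserved district: {p_new_mean:.3f}, average district: {p_average_district:.3f}")
```

The unobserved district's mean sits near the average district, and its spread combines the uncertainty in `alpha_bar` with the between-district variation `sigma`, which is why its interval is wide.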
+++
utils.draw_causal_graph(
edge_list=[
("A", "C"),
("K", "C"),
("U", "C"),
("D", "C"),
("D", "U"),
("U", "K"),
("A", "K"),
("D", "K"),
],
node_props={
"A": {"color": "lightgray"},
"K": {"color": "lightgray"},
},
edge_props={("A", "K"): {"color": "lightgray"}, ("A", "C"): {"color": "lightgray"}},
)
- What is the effect of urban living?
- Beware:
  - district features have potential group-level confounds
  - Total effect of $U$ passes through $K$
  - Do not stratify by $K$ -- it's a collider that opens up the district-level confound through $D$
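The danger of stratifying by a collider can be shown with a toy simulation. This is a deliberately simplified sketch, not the fertility DAG itself: `Z` is a hypothetical unobserved cause standing in for the confound, all variables are binary, and the true effect of `U` on `C` is set to zero so that any apparent effect is pure bias.

```python
import random

random.seed(2)


def mean(xs):
    return sum(xs) / len(xs)


rows = []
for _ in range(200_000):
    u = random.random() < 0.5
    z = random.random() < 0.5  # hypothetical confound (an assumption of this sketch)
    k = random.random() < 0.1 + 0.4 * u + 0.4 * z  # K is a common effect of U and Z
    c = random.random() < 0.2 + 0.6 * z  # C depends on Z only: true effect of U is zero
    rows.append((u, k, c))


def effect_of_u(subset):
    """Difference in the rate of C between U=1 and U=0 rows."""
    treated = [c for u, _, c in subset if u]
    control = [c for u, _, c in subset if not u]
    return mean(treated) - mean(control)


overall = effect_of_u(rows)
within_k1 = effect_of_u([row for row in rows if row[1]])
print(f"unstratified: {overall:+.3f}, within K=1: {within_k1:+.3f}")
```

Without stratification the estimated effect is close to zero, as it should be. Within the `K=1` stratum a spurious negative association appears, because conditioning on `K` makes `U` informative about `Z`.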
+++
utils.draw_causal_graph(
edge_list=[
("D", "C"),
("U", "C"),
],
node_props={
"D": {"label": "district, D"},
"C": {"label": "contraceptive use, C"},
},
)
+++
We use the non-centered parameterization of the priors here -- details about non-centering will be discussed later
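Written out, the model has the following form. The priors are read off the PyMC code for this model, and the subscript notation is an assumption of this sketch:

$$
\begin{align*}
C_i &\sim \text{Bernoulli}(p_i) \\
\text{logit}(p_i) &= \alpha_{D[i]} + \beta_{D[i]} U_i \\
\alpha_j &= \bar\alpha + z_{\alpha,j} \, \sigma, \quad z_{\alpha,j} \sim \text{Normal}(0, 1) \\
\beta_j &= \bar\beta + z_{\beta,j} \, \tau, \quad z_{\beta,j} \sim \text{Normal}(0, 1) \\
\bar\alpha, \bar\beta &\sim \text{Normal}(0, 1) \\
\sigma, \tau &\sim \text{Exponential}(1)
\end{align*}
$$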
URBAN_CODED, URBAN = pd.factorize(FERTILITY.urban, sort=True)

with pm.Model(coords={"district": DISTRICT}) as district_urban_model:
    # Mutable data
    urban = pm.Data("urban", URBAN_CODED)

    # Priors
    # District offset
    alpha_bar = pm.Normal("alpha_bar", 0, 1)  # the average district
    sigma = pm.Exponential("sigma", 1)  # variation amongst districts

    # Non-centered parameterization
    z_alpha = pm.Normal("z_alpha", 0, 1, dims="district")
    alpha = alpha_bar + z_alpha * sigma

    # District / urban interaction
    beta_bar = pm.Normal("beta_bar", 0, 1)  # the average urban effect
    tau = pm.Exponential("tau", 1)  # variation amongst urban

    # Non-centered parameterization
    z_beta = pm.Normal("z_beta", 0, 1, dims="district")
    beta = beta_bar + z_beta * tau

    # Record p(contraceptive)
    p_C = pm.Deterministic("p_C", pm.math.invlogit(alpha + beta))
    p_C_urban = pm.Deterministic("p_C_urban", pm.math.invlogit(alpha + beta))
    p_C_rural = pm.Deterministic("p_C_rural", pm.math.invlogit(alpha))

    # Likelihood
    p = pm.math.invlogit(alpha[DISTRICT_ID] + beta[DISTRICT_ID] * urban)
    C = pm.Bernoulli("C", p=p, observed=USES_CONTRACEPTION)

    district_urban_inference = pm.sample(target_accept=0.95)
az.summary(district_urban_inference, var_names=["alpha_bar", "beta_bar", "tau", "sigma"])
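The non-centered form used in the model defines the same distribution as the centered form; only the quantity the sampler explores changes. A small sketch checking this by simulation, with hypothetical values for `ALPHA_BAR` and `SIGMA`:

```python
import random
import statistics

random.seed(3)

# Hypothetical population parameters (assumed for illustration)
ALPHA_BAR, SIGMA = -0.5, 0.8
N = 50_000

# Centered: draw alpha directly from Normal(alpha_bar, sigma)
centered = [random.gauss(ALPHA_BAR, SIGMA) for _ in range(N)]

# Non-centered: draw a standard-normal z, then shift and scale it
z = [random.gauss(0, 1) for _ in range(N)]
non_centered = [ALPHA_BAR + z_i * SIGMA for z_i in z]

for name, draws in [("centered", centered), ("non-centered", non_centered)]:
    print(f"{name}: mean={statistics.mean(draws):.3f}, sd={statistics.stdev(draws):.3f}")
```

In the non-centered version the sampled variable `z` has a fixed unit scale that does not depend on `sigma`, which often makes the posterior geometry easier for the sampler when clusters have little data.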
fig, axs = plt.subplots(2, 1, figsize=(10, 8))
for ii, (label, var) in enumerate(zip(["rural", "urban"], ["p_C_rural", "p_C_urban"])):
    plot_survival_posterior(
        district_urban_inference, color=f"C{ii}", var=var, data_filter=label, ax=axs[ii]
    )
    plt.title(label)

# Save fig for reference in next lecture
utils.savefig("fertility_posterior_means_rural_urban.png")
The plots above indicate that urban areas have
- higher overall rates of contraceptive use
- higher variance, as shown by the error bars
The plot below re-confirms that variance in contraceptive use is indeed larger in urban areas: the posterior for the urban standard deviation parameter $\tau$ is shifted toward larger values than that of the rural parameter $\sigma$.
for ii, (label, var) in enumerate(zip(["rural, $\\sigma$", "urban, $\\tau$"], ["sigma", "tau"])):
    az.plot_dist(district_urban_inference.posterior[var], color=f"C{ii}", label=label)


def exponential_prior(x, lambda_=1):
    return lambda_ * np.exp(-lambda_ * x)


xs = np.linspace(0, 1.2)
plt.plot(xs, exponential_prior(xs), label="prior", color="k", linestyle="--")
plt.xlim([0, 1.2])
plt.xlabel("posterior std dev")
plt.ylabel("density")
plt.legend();
posterior_mean = district_urban_inference.posterior.mean(dim=("chain", "draw"))
utils.plot_scatter(
posterior_mean["p_C_rural"], posterior_mean["p_C_urban"], color="C0", label="posterior means"
)
correlation = np.corrcoef(posterior_mean["p_C_rural"], posterior_mean["p_C_urban"])[0][1]
plt.axvline(0.5, color="gray", linestyle="dashed")
plt.axhline(0.5, color="gray", linestyle="dashed")
plt.xlabel("prob C (rural)")
plt.ylabel("prob C (urban)")
plt.title(f"Correlation: {correlation:1.2f}")
plt.legend()
plt.axis("square")
# Save fig for reference in next lecture
utils.savefig("fertility_p_C_rural_urban.png");
- There's more contraceptive use in urban areas than in rural areas -- a majority of points are to the left of the vertical line, i.e. most districts have a rural probability below 0.5
- There's a high amount of correlation (correlation coefficient > 0.7) between contraceptive use in rural and urban areas in each district.
We should be able to take advantage of this correlation information to make better estimates. MORE TO COME!
+++
- Clusters: distinct groups in the data
- Features: aspects of the model (e.g. parameters) that vary by cluster
- There is useful information transferred across features
- We can use partial pooling to efficiently estimate features, even in absence of data
+++
- Ported to PyMC by Dustin Stansbury (2024)
- Based on Statistical Rethinking (2023) lectures by Richard McElreath
%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray
:::{include} ../page_footer.md
:::