Skip to content

Commit 47da13d

Browse files
committed
Run with explained variance calculations
Signed-off-by: Nathaniel <[email protected]>
1 parent d15639a commit 47da13d

File tree

2 files changed

+1312
-831
lines changed

2 files changed

+1312
-831
lines changed

examples/case_studies/CFA_SEM.ipynb

Lines changed: 1247 additions & 826 deletions
Large diffs are not rendered by default.

examples/case_studies/CFA_SEM.myst.md

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ del idata
252252

253253
### Intermediate Cross-Loading Model
254254

255-
The idea of a measurment is maybe a little opaque when we only see models that fit well. Instead we want to briefly show how a in-apt measurement model gets reflected in the estimated parameters for the factor loadings. Here we specify a measurement model which attempts to couple the `se_social` and `sup_parents` indicators and bundle them into the same factor.
255+
The idea of a measurment is maybe a little opaque when we only see models that fit well. Instead we want to briefly show how an in-apt measurement model gets reflected in the estimated parameters for the factor loadings. Here we specify a measurement model which attempts to couple the `se_social` and `sup_parents` indicators and bundle them into the same factor.
256256

257257
```{code-cell} ipython3
258258
coords = {
@@ -352,19 +352,23 @@ Again our model samples well but the parameter estimates suggest that there is s
352352
az.summary(idata, var_names=["lambdas1", "lambdas2"])
353353
```
354354

355+
This is similarly refected in the diagnostic energy plots here too.
356+
355357
```{code-cell} ipython3
356358
fig, axs = plt.subplots(1, 2, figsize=(20, 9))
357359
axs = axs.flatten()
358360
az.plot_energy(idata, ax=axs[0])
359361
az.plot_forest(idata, var_names=["lambdas1"], combined=True, ax=axs[1]);
360362
```
361363

362-
This hints at a variety of measurement model misspecification and should force us back to the drawing board.
364+
This hints at a variety of measurement model misspecification and should force us back to the drawing board. An appropriate measurement model maps the indicator variables to a consistently defined latent construct that plausibly reflects aspects of the realised indicator metrics.
363365

364366
+++
365367

366368
## Full Measurement Model
367369

370+
With this in mind we'll now specify a full measurement that maps each of our thematically similar indicator metrics to an indicidual latent construct. This mandates the postulation of 5 distinct constructs.
371+
368372
```{code-cell} ipython3
369373
drivers = [
370374
"se_acad_p1",
@@ -509,7 +513,62 @@ ax.set_title("Residuals between Model Implied and Sample Covariances", fontsize=
509513
```
510514

511515
```{code-cell} ipython3
512-
make_ppc(idata_mm, 100, drivers=residuals_posterior_cov.columns, dims=(5, 3));
516+
make_ppc(idata_mm, 100, drivers=residuals_posterior_cov.columns, dims=(3, 5));
517+
```
518+
519+
```{code-cell} ipython3
520+
cov_df = pd.DataFrame(az.extract(idata_mm["posterior"])["cov"].mean(axis=2))
521+
cov_df.index = ["SE_ACAD", "SE_SOCIAL", "SUP_F", "SUP_P", "LS"]
522+
cov_df.columns = ["SE_ACAD", "SE_SOCIAL", "SUP_F", "SUP_P", "LS"]
523+
524+
correlation_df = pd.DataFrame(az.extract(idata_mm["posterior"])["chol_cov_corr"].mean(axis=2))
525+
correlation_df.index = ["SE_ACAD", "SE_SOCIAL", "SUP_F", "SUP_P", "LS"]
526+
correlation_df.columns = ["SE_ACAD", "SE_SOCIAL", "SUP_F", "SUP_P", "LS"]
527+
528+
factor_loadings = pd.DataFrame(
529+
az.summary(idata_mm, var_names=["lambdas1", "lambdas2", "lambdas3", "lambdas4", "lambdas5"])[
530+
"mean"
531+
]
532+
).reset_index()
533+
factor_loadings["factor"] = factor_loadings["index"].str.split("[", expand=True)[0]
534+
factor_loadings.columns = ["factor_loading", "factor_loading_weight", "factor"]
535+
factor_loadings["factor_loading_weight_sq"] = factor_loadings["factor_loading_weight"] ** 2
536+
factor_loadings["sum_sq_loadings"] = factor_loadings.groupby("factor")[
537+
"factor_loading_weight_sq"
538+
].transform(sum)
539+
factor_loadings["error_variances"] = az.summary(idata_mm, var_names=["Psi"])["mean"].values
540+
factor_loadings["total_indicator_variance"] = (
541+
factor_loadings["factor_loading_weight_sq"] + factor_loadings["error_variances"]
542+
)
543+
factor_loadings["total_variance"] = factor_loadings["total_indicator_variance"].sum()
544+
factor_loadings["indicator_explained_variance"] = (
545+
factor_loadings["factor_loading_weight_sq"] / factor_loadings["total_variance"]
546+
)
547+
factor_loadings["factor_explained_variance"] = (
548+
factor_loadings["sum_sq_loadings"] / factor_loadings["total_variance"]
549+
)
550+
factor_loadings.style.background_gradient(
551+
axis=0, subset=["indicator_explained_variance", "factor_explained_variance"]
552+
)
553+
```
554+
555+
```{code-cell} ipython3
556+
fig, ax = plt.subplots(figsize=(20, 6))
557+
temp = factor_loadings[["factor_loading", "indicator_explained_variance"]].sort_values(
558+
by="indicator_explained_variance"
559+
)
560+
ax.barh(temp["factor_loading"], temp["indicator_explained_variance"], align="center")
561+
ax.set_title("Explained Variance")
562+
```
563+
564+
```{code-cell} ipython3
565+
fig, axs = plt.subplots(1, 2, figsize=(20, 6))
566+
axs = axs.flatten()
567+
mask = np.triu(np.ones_like(cov_df, dtype=bool))
568+
sns.heatmap(cov_df, annot=True, cmap="Blues", ax=axs[0], mask=mask)
569+
axs[0].set_title("Covariance of Latent Constructs")
570+
axs[1].set_title("Covariance of Latent Constructs")
571+
sns.heatmap(correlation_df, annot=True, cmap="Blues", ax=axs[1], mask=mask);
513572
```
514573

515574
## Bayesian Structural Equation Models
@@ -596,8 +655,9 @@ def make_indirect_sem(priors):
596655
# Regression Components
597656
beta_r = pm.Normal("beta_r", 0, priors["beta_r"], dims="latent_regression")
598657
beta_r2 = pm.Normal("beta_r2", 0, priors["beta_r2"], dims="regression")
658+
sd_dist1 = pm.Exponential.dist(1.0, shape=2)
599659
resid_chol, _, _ = pm.LKJCholeskyCov(
600-
"resid_chol", n=2, eta=priors["eta"], sd_dist=sd_dist, compute_corr=True
660+
"resid_chol", n=2, eta=1, sd_dist=sd_dist1, compute_corr=True
601661
)
602662
_ = pm.Deterministic("resid_cov", chol.dot(chol.T))
603663
sigmas_resid = pm.MvNormal("sigmas_resid", kappa, chol=resid_chol)
@@ -673,7 +733,7 @@ model_sem2, idata_sem2 = make_indirect_sem(
673733
```
674734

675735
```{code-cell} ipython3
676-
fig, ax = plt.subplots(figsize=(20, 15))
736+
fig, ax = plt.subplots(figsize=(10, 15))
677737
az.plot_forest(
678738
[idata_sem0, idata_sem1, idata_sem2],
679739
model_names=["SEM0", "SEM1", "SEM2"],

0 commit comments

Comments
 (0)