Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
1f4f695
adding vs module
NathanielF Nov 10, 2025
2a01933
Merge branch 'main' into vs_module
NathanielF Nov 19, 2025
db1e81d
adding demo notebook
NathanielF Nov 20, 2025
db0522e
trying to fix doctests
NathanielF Nov 20, 2025
1670af4
adding fix
NathanielF Nov 20, 2025
83061e4
another fix
NathanielF Nov 20, 2025
05dc20d
update fix
NathanielF Nov 20, 2025
73e6a8d
adding tests
NathanielF Nov 20, 2025
8d6251f
update adding more tests
NathanielF Nov 21, 2025
4577106
add normal test vs_prior
NathanielF Nov 21, 2025
fd1bfb7
improving test coverage
NathanielF Nov 21, 2025
7375aa5
updating notebook
NathanielF Nov 21, 2025
896ee57
update index
NathanielF Nov 21, 2025
1bb79d3
Merge branch 'main' into vs_module
NathanielF Nov 21, 2025
068c922
update spelling
NathanielF Nov 21, 2025
13b320e
add binary treatment case to IV model
NathanielF Nov 22, 2025
6210116
adding more write up
NathanielF Nov 26, 2025
bf5b404
update heading
NathanielF Nov 26, 2025
b16ef30
hide cells
NathanielF Nov 26, 2025
b88271e
better story telling
NathanielF Dec 9, 2025
c452650
tidying
NathanielF Dec 9, 2025
a5dd60f
hide cell inputs
NathanielF Dec 9, 2025
2578cce
Merge branch 'main' into vs_module
NathanielF Dec 9, 2025
a7c1090
fixing linting
NathanielF Dec 9, 2025
78ed0ce
fix linting error
NathanielF Dec 9, 2025
9e6ede0
update typing
NathanielF Dec 9, 2025
dbc8614
spell check
NathanielF Dec 9, 2025
12936a4
update for juan's comments.
NathanielF Dec 22, 2025
97090b8
update tests to include outcome parameter in vs_hyperparams
NathanielF Dec 22, 2025
6e7accf
Merge branch 'main' into vs_module
NathanielF Dec 22, 2025
4b315fd
add type hints
NathanielF Dec 23, 2025
ca37024
bug fix
NathanielF Dec 23, 2025
18da6c4
fixing the warning message
NathanielF Dec 23, 2025
a466161
update iv_vs_priors notebook
NathanielF Dec 24, 2025
8020248
full notebook re-run
NathanielF Dec 25, 2025
452a63c
fix init import bug
NathanielF Dec 25, 2025
f377505
Merge branch 'main' into pr/568
drbenvincent Jan 2, 2026
2576b66
get pre-commit checks passing
drbenvincent Jan 2, 2026
4dae630
tidying spelling
NathanielF Jan 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions causalpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@
"RegressionKink",
"skl_models",
"SyntheticControl",
"variable_selection_priors",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing import for variable_selection_priors module export

The variable_selection_priors module is added to __all__ but there's no corresponding import statement in the file. Other modules like pymc_models and skl_models have explicit imports (e.g., import causalpy.pymc_models as pymc_models), but this pattern is missing for variable_selection_priors. This causes from causalpy import variable_selection_priors and from causalpy import * to fail with an import error.

Fix in Cursor Fix in Web

]
19 changes: 18 additions & 1 deletion causalpy/experiments/instrumental_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ class InstrumentalVariable(BaseExperiment):
If priors are not specified we will substitute MLE estimates for
the beta coefficients. Example: ``priors = {"mus": [0, 0],
"sigmas": [1, 1], "eta": 2, "lkj_sd": 2}``.
:param vs_prior_type : str or None, default=None
Type of variable selection prior: 'spike_and_slab', 'horseshoe', or None.
If None, uses standard normal priors.
:param vs_hyperparams : dict, optional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is sphinx format and not numpy

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add that into AGENTS.md if it's not already there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is sphinx format and not numpy

You got me. The doc strings were AI generated. Will fix.

Hyperparameters for variable selection priors. Only used if vs_prior_type
is not None.

Example
--------
Expand Down Expand Up @@ -98,6 +104,8 @@ def __init__(
formula: str,
model: BaseExperiment | None = None,
priors: dict | None = None,
vs_prior_type=None,
vs_hyperparams=None,
**kwargs: dict,
) -> None:
super().__init__(model=model)
Expand All @@ -107,6 +115,8 @@ def __init__(
self.formula = formula
self.instruments_formula = instruments_formula
self.model = model
self.vs_prior_type = (vs_prior_type,)
self.vs_hyperparams = vs_hyperparams or {}
self.input_validation()

y, X = dmatrices(formula, self.data)
Expand Down Expand Up @@ -138,7 +148,14 @@ def __init__(
}
self.priors = priors
self.model.fit( # type: ignore[call-arg,union-attr]
X=self.X, Z=self.Z, y=self.y, t=self.t, coords=COORDS, priors=self.priors
X=self.X,
Z=self.Z,
y=self.y,
t=self.t,
coords=COORDS,
priors=self.priors,
vs_prior_type=vs_prior_type,
vs_hyperparams=vs_hyperparams,
)

def input_validation(self) -> None:
Expand Down
109 changes: 59 additions & 50 deletions causalpy/pymc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pymc_extras.prior import Prior

from causalpy.utils import round_num
from causalpy.variable_selection_priors import VariableSelectionPrior


class PyMCModel(pm.Model):
Expand Down Expand Up @@ -657,7 +658,9 @@ def build_model( # type: ignore
y: np.ndarray,
t: np.ndarray,
coords: Dict[str, Any],
priors: Dict[str, Any],
priors,
vs_prior_type=None,
vs_hyperparams=None,
) -> None:
"""Specify model with treatment regression and focal regression
data and priors.
Expand All @@ -680,23 +683,47 @@ def build_model( # type: ignore
Dictionary of priors for the mus and sigmas of both
regressions. Example: ``priors = {"mus": [0, 0],
"sigmas": [1, 1], "eta": 2, "lkj_sd": 2}``.
:param vs_prior_type: An optional string. Can be "spike_and_slab"
or "horseshoe" or "normal
:param vs_hyperparams: An optional dictionary of priors for the
variable selection hyperparameters

"""

# --- Priors ---
with self:
self.add_coords(coords)
beta_t = pm.Normal(
name="beta_t",
mu=priors["mus"][0],
sigma=priors["sigmas"][0],
dims="instruments",
)
beta_z = pm.Normal(
name="beta_z",
mu=priors["mus"][1],
sigma=priors["sigmas"][1],
dims="covariates",
)

# Create coefficient priors
if vs_prior_type:
# Use variable selection priors
vs_prior_treatment = VariableSelectionPrior(
vs_prior_type, vs_hyperparams
)
vs_prior_outcome = VariableSelectionPrior(vs_prior_type, vs_hyperparams)

beta_t = vs_prior_treatment.create_prior(
name="beta_t", n_params=Z.shape[1], dims="instruments", X=Z
)

beta_z = vs_prior_outcome.create_prior(
name="beta_z", n_params=X.shape[1], dims="covariates", X=X
)
else:
# Use standard normal priors
beta_t = pm.Normal(
name="beta_t",
mu=priors["mus"][0],
sigma=priors["sigmas"][0],
dims="instruments",
)
beta_z = pm.Normal(
name="beta_z",
mu=priors["mus"][1],
sigma=priors["sigmas"][1],
dims="covariates",
)

sd_dist = pm.Exponential.dist(priors["lkj_sd"], shape=2)
chol, corr, sigmas = pm.LKJCholeskyCov(
name="chol_cov",
Expand Down Expand Up @@ -755,50 +782,32 @@ def sample_predictive_distribution(self, ppc_sampler: str | None = "jax") -> Non
)
)

def fit( # type: ignore
def fit(
self,
X: np.ndarray,
Z: np.ndarray,
y: np.ndarray,
t: np.ndarray,
coords: Dict[str, Any],
priors: Dict[str, Any],
ppc_sampler: str | None = None,
) -> az.InferenceData:
"""Draw samples from posterior distribution and potentially from
the prior and posterior predictive distributions.

Parameters
----------
X : np.ndarray
Array used to predict our outcome y.
Z : np.ndarray
Array used to predict our treatment variable t.
y : np.ndarray
Array of values representing our focal outcome y.
t : np.ndarray
Array representing the treatment variable.
coords : dict
Dictionary with coordinate names for named dimensions.
priors : dict
Dictionary of priors for the model.
ppc_sampler : str, optional
Sampler for posterior predictive distribution. Can be 'jax',
'pymc', or None. Defaults to None, so the user can determine
if they wish to spend time sampling the posterior predictive
distribution independently.

Returns
-------
az.InferenceData
InferenceData object containing the samples.
X,
Z,
y,
t,
coords,
priors,
ppc_sampler=None,
vs_prior_type=None,
vs_hyperparams=None,
):
"""Draw samples from posterior distribution and potentially
from the prior and posterior predictive distributions. The
fit call can take values for the
ppc_sampler = ['jax', 'pymc', None]
We default to None, so the user can determine if they wish
to spend time sampling the posterior predictive distribution
independently.
"""

# Ensure random_seed is used in sample_prior_predictive() and
# sample_posterior_predictive() if provided in sample_kwargs.
# Use JAX for ppc sampling of multivariate likelihood

self.build_model(X, Z, y, t, coords, priors)
self.build_model(X, Z, y, t, coords, priors, vs_prior_type, vs_hyperparams)
with self:
self.idata = pm.sample(**self.sample_kwargs)
self.sample_predictive_distribution(ppc_sampler=ppc_sampler)
Expand Down
32 changes: 32 additions & 0 deletions causalpy/tests/test_integration_pymc_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,38 @@ def test_iv_reg(mock_pymc_sample):
result.get_plot_data()


@pytest.mark.integration
def test_iv_reg_vs_prior(mock_pymc_sample):
df = cp.load_data("risk")
instruments_formula = "risk ~ 1 + logmort0"
formula = "loggdp ~ 1 + risk"
instruments_data = df[["risk", "logmort0"]]
data = df[["loggdp", "risk"]]

result = cp.InstrumentalVariable(
instruments_data=instruments_data,
data=data,
instruments_formula=instruments_formula,
formula=formula,
model=cp.pymc_models.InstrumentalVariableRegression(
sample_kwargs=sample_kwargs
),
vs_prior_type="spike_and_slab",
vs_hyperparams={"pi_alpha": 5},
)
result.model.sample_predictive_distribution(ppc_sampler="pymc")
assert isinstance(df, pd.DataFrame)
assert isinstance(data, pd.DataFrame)
assert isinstance(instruments_data, pd.DataFrame)
assert isinstance(result, cp.InstrumentalVariable)
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
with pytest.raises(NotImplementedError):
result.get_plot_data()
assert "gamma_beta_t" in result.model.named_vars
assert "pi_beta_t" in result.model.named_vars


@pytest.mark.integration
def test_inverse_prop(mock_pymc_sample):
"""Test the InversePropensityWeighting class."""
Expand Down
Loading
Loading