Skip to content

Commit 96ba24b

Browse files
committed
add warning when bandwidth parameter leads to only a few remaining datapoints
1 parent e0a42dc commit 96ba24b

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

causalpy/pymc_experiments.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Optional, Union
23

34
import arviz as az
@@ -586,6 +587,11 @@ def __init__(
586587
fmin = self.treatment_threshold - self.bandwidth
587588
fmax = self.treatment_threshold + self.bandwidth
588589
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
590+
if len(filtered_data) <= 10:
591+
warnings.warn(
592+
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
593+
UserWarning,
594+
)
589595
y, X = dmatrices(formula, filtered_data)
590596
else:
591597
y, X = dmatrices(formula, self.data)

causalpy/skl_experiments.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Optional
23

34
import matplotlib.pyplot as plt
@@ -391,6 +392,11 @@ def __init__(
391392
fmin = self.treatment_threshold - self.bandwidth
392393
fmax = self.treatment_threshold + self.bandwidth
393394
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
395+
if len(filtered_data) <= 10:
396+
warnings.warn(
397+
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
398+
UserWarning,
399+
)
394400
y, X = dmatrices(formula, filtered_data)
395401
else:
396402
y, X = dmatrices(formula, self.data)

0 commit comments

Comments
 (0)