Skip to content

Commit 234a0cd

Browse files
authored
Merge pull request #232 from jpreszler/issue_129_docstring_additions
Issue 129: increase docstring coverage
2 parents d7a12cb + c80d78e commit 234a0cd

24 files changed

+1108
-75
lines changed

.github/workflows/ci.yml

+4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ jobs:
3434
uses: actions/setup-python@v3
3535
with:
3636
python-version: ${{ matrix.python-version }}
37+
- name: Run doctests
38+
run: |
39+
pip install -e .[test]
40+
pytest --doctest-modules causalpy/
3741
- name: Run tests
3842
run: |
3943
pip install -e .[test]

CONTRIBUTING.md

+14
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,20 @@ We recommend that your contribution complies with the following guidelines befor
121121

122122
- All public methods must have informative docstrings with sample usage when appropriate.
123123

124+
- Example usage in docstrings is tested via doctest, which can be run with:
125+
126+
```bash
127+
make doctest
128+
```
129+
130+
- Doctest can also be run directly via pytest, which can be helpful to run only specific tests during development. The following commands run all doctests, only doctests in the pymc_models module, and only the doctests for the `ModelBuilder` class in pymc_models:
131+
132+
```bash
133+
pytest --doctest-modules causalpy/
134+
pytest --doctest-modules causalpy/pymc_models.py
135+
pytest --doctest-modules causalpy/pymc_models.py::causalpy.pymc_models.ModelBuilder
136+
```
137+
124138
- To indicate a work in progress please mark the PR as `draft`. Drafts may be useful to (1) indicate you are working on something to avoid duplicated work, (2) request broad review of functionality or API, or (3) seek collaborators.
125139

126140
- All other tests pass when everything is rebuilt from scratch. Tests can be run with:

Makefile

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ check_lint:
1717
nbqa isort --check-only .
1818
interrogate .
1919

20+
doctest:
21+
pip install causalpy[test]
22+
pytest --doctest-modules causalpy/
23+
2024
test:
2125
pip install causalpy[test]
2226
pytest

causalpy/custom_exceptions.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
1+
"""
2+
Custom Exceptions for CausalPy.
3+
"""
4+
5+
16
class BadIndexException(Exception):
27
"""Custom exception used when we have a mismatch in types between the dataframe
38
index and an event, typically a treatment or intervention."""
49

5-
def __init__(self, message):
10+
def __init__(self, message: str):
611
self.message = message
712

813

914
class FormulaException(Exception):
1015
"""Exception raised when there is an error in a user-provided model
1116
formula"""
1217

13-
def __init__(self, message):
18+
def __init__(self, message: str):
1419
self.message = message
1520

1621

1722
class DataException(Exception):
1823
"""Exception raised when there is an error in a user-provided dataframe"""
1924

20-
def __init__(self, message):
25+
def __init__(self, message: str):
2126
self.message = message

causalpy/data/datasets.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""
2+
Functions to load example datasets
3+
"""
14
import pathlib
25

36
import pandas as pd

causalpy/data/simulate_data.py

+92-10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""
2+
Functions that generate data sets used in examples
3+
"""
14
import numpy as np
25
import pandas as pd
36
from scipy.stats import dirichlet, gamma, norm, uniform
@@ -11,6 +14,18 @@
1114
def _smoothed_gaussian_random_walk(
1215
gaussian_random_walk_mu, gaussian_random_walk_sigma, N, lowess_kwargs
1316
):
17+
"""
18+
Generates Gaussian random walk data and applies LOWESS
19+
20+
:param gaussian_random_walk_mu:
21+
Mean of the random walk
22+
:param gaussian_random_walk_sigma:
23+
Standard deviation of the random walk
24+
:param N:
25+
Length of the random walk
26+
:param lowess_kwargs:
27+
Keyword argument dictionary passed to statsmodels lowess
28+
"""
1429
x = np.arange(N)
1530
y = norm(gaussian_random_walk_mu, gaussian_random_walk_sigma).rvs(N).cumsum()
1631
filtered = lowess(y, x, **lowess_kwargs)
@@ -26,12 +41,25 @@ def generate_synthetic_control_data(
2641
lowess_kwargs=default_lowess_kwargs,
2742
):
2843
"""
29-
Example:
30-
>> import pathlib
31-
>> df, weightings_true = generate_synthetic_control_data(
32-
treatment_time=treatment_time
33-
)
34-
>> df.to_csv(pathlib.Path.cwd() / 'synthetic_control.csv', index=False)
44+
Generates data for synthetic control example.
45+
46+
:param N:
47+
Number of data points
48+
:param treatment_time:
49+
Index where treatment begins in the generated dataframe
50+
:param grw_mu:
51+
Mean of Gaussian Random Walk
52+
:param grw_sigma:
53+
Standard deviation of Gaussian Random Walk
54+
:param lowess_kwargs:
55+
Keyword argument dictionary passed to statsmodels lowess
56+
57+
Example
58+
--------
59+
>>> from causalpy.data.simulate_data import generate_synthetic_control_data
60+
>>> df, weightings_true = generate_synthetic_control_data(
61+
... treatment_time=70
62+
... )
3563
"""
3664

3765
# 1. Generate non-treated variables
@@ -70,6 +98,21 @@ def generate_synthetic_control_data(
7098
def generate_time_series_data(
7199
N=100, treatment_time=70, beta_temp=-1, beta_linear=0.5, beta_intercept=3
72100
):
101+
"""
102+
Generates interrupted time series example data
103+
104+
:param N:
105+
Length of the time series
106+
:param treatment_time:
107+
Index of when treatment begins
108+
:param beta_temp:
109+
The temperature coefficient
110+
:param beta_linear:
111+
The linear coefficient
112+
:param beta_intercept:
113+
The intercept
114+
115+
"""
73116
x = np.arange(0, 100, 1)
74117
df = pd.DataFrame(
75118
{
@@ -99,6 +142,9 @@ def generate_time_series_data(
99142

100143

101144
def generate_time_series_data_seasonal(treatment_time):
145+
"""
146+
Generates 10 years of monthly data with seasonality
147+
"""
102148
dates = pd.date_range(
103149
start=pd.to_datetime("2010-01-01"), end=pd.to_datetime("2020-01-01"), freq="M"
104150
)
@@ -146,6 +192,14 @@ def generate_time_series_data_simple(treatment_time, slope=0.0):
146192

147193

148194
def generate_did():
195+
"""
196+
Generate Difference in Differences data
197+
198+
Example
199+
--------
200+
>>> from causalpy.data.simulate_data import generate_did
201+
>>> df = generate_did()
202+
"""
149203
# true parameters
150204
control_intercept = 1
151205
treat_intercept_delta = 0.25
@@ -157,6 +211,7 @@ def generate_did():
157211
def outcome(
158212
t, control_intercept, treat_intercept_delta, trend, Δ, group, post_treatment
159213
):
214+
"""Compute the outcome of each unit"""
160215
return (
161216
control_intercept
162217
+ (treat_intercept_delta * group)
@@ -191,16 +246,23 @@ def generate_regression_discontinuity_data(
191246
N=100, true_causal_impact=0.5, true_treatment_threshold=0.0
192247
):
193248
"""
194-
Example use:
195-
>> import pathlib
196-
>> df = generate_regression_discontinuity_data(true_treatment_threshold=0.5)
197-
>> df.to_csv(pathlib.Path.cwd() / 'regression_discontinuity.csv', index=False)
249+
Generate regression discontinuity example data
250+
251+
Example
252+
--------
253+
>>> import pathlib
254+
>>> from causalpy.data.simulate_data import generate_regression_discontinuity_data
255+
>>> df = generate_regression_discontinuity_data(true_treatment_threshold=0.5)
256+
>>> df.to_csv(pathlib.Path.cwd() / 'regression_discontinuity.csv',
257+
... index=False) # doctest: +SKIP
198258
"""
199259

200260
def is_treated(x):
261+
"""Check if x was treated"""
201262
return np.greater_equal(x, true_treatment_threshold)
202263

203264
def impact(x):
265+
"""Assign true_causal_impact to all treated entries"""
204266
y = np.zeros(len(x))
205267
y[is_treated(x)] = true_causal_impact
206268
return y
@@ -214,6 +276,22 @@ def impact(x):
214276
def generate_ancova_data(
215277
N=200, pre_treatment_means=np.array([10, 12]), treatment_effect=2, sigma=1
216278
):
279+
"""
280+
Generate ANCOVA example data
281+
282+
Example
283+
--------
284+
>>> import pathlib
285+
>>> from causalpy.data.simulate_data import generate_ancova_data
286+
>>> df = generate_ancova_data(
287+
... N=200,
288+
... pre_treatment_means=np.array([10, 12]),
289+
... treatment_effect=2,
290+
... sigma=1
291+
... )
292+
>>> df.to_csv(pathlib.Path.cwd() / 'ancova_data.csv',
293+
... index=False) # doctest: +SKIP
294+
"""
217295
group = np.random.choice(2, size=N)
218296
pre = np.random.normal(loc=pre_treatment_means[group])
219297
post = pre + treatment_effect * group + np.random.normal(size=N) * sigma
@@ -233,6 +311,10 @@ def generate_geolift_data():
233311
causal_impact = 0.2
234312

235313
def create_series(n=52, amplitude=1, length_scale=2):
314+
"""
315+
Returns numpy tile with generated seasonality data repeated over
316+
multiple years
317+
"""
236318
return np.tile(
237319
generate_seasonality(n=n, amplitude=amplitude, length_scale=2) + 3, n_years
238320
)

causalpy/plot_utils.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Plotting utility functions.
3+
"""
4+
15
from typing import Any, Dict, Optional, Tuple, Union
26

37
import arviz as az
@@ -17,7 +21,22 @@ def plot_xY(
1721
hdi_prob: float = 0.94,
1822
label: Union[str, None] = None,
1923
) -> Tuple[Line2D, PolyCollection]:
20-
"""Utility function to plot HDI intervals."""
24+
"""
25+
Utility function to plot HDI intervals.
26+
27+
:param x:
28+
Pandas datetime index or numpy array of x-axis values
29+
:param y:
30+
Xarray data array of y-axis data
31+
:param ax:
32+
Matplotlib ax object
33+
:param plot_hdi_kwargs:
34+
Dictionary of keyword arguments passed to ax.plot()
35+
:param hdi_prob:
36+
The size of the HDI, default is 0.94
37+
:param label:
38+
The plot label
39+
"""
2140

2241
if plot_hdi_kwargs is None:
2342
plot_hdi_kwargs = {}

0 commit comments

Comments
 (0)