Skip to content

Commit 1742d44

Browse files
committed
Allow linear model formula to extract variables from calling scope.
1 parent f541c5a commit 1742d44

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

pymc3/glm/linear.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@ def __init__(self, x, y, intercept=True, labels=None,
8484

8585
@classmethod
8686
def from_formula(cls, formula, data, priors=None, vars=None,
87-
name='', model=None, offset=0.):
87+
name='', model=None, offset=0., eval_env=0):
8888
import patsy
89-
y, x = patsy.dmatrices(formula, data)
89+
eval_env = patsy.EvalEnvironment.capture(eval_env, reference=1)
90+
y, x = patsy.dmatrices(formula, data, eval_env=eval_env)
9091
labels = x.design_info.column_names
9192
return cls(np.asarray(x), np.asarray(y)[:, -1], intercept=False,
9293
labels=labels, priors=priors, vars=vars, name=name,
@@ -140,9 +141,10 @@ def __init__(self, x, y, intercept=True, labels=None,
140141
@classmethod
141142
def from_formula(cls, formula, data, priors=None,
142143
vars=None, family='normal', name='',
143-
model=None, offset=0.):
144+
model=None, offset=0., eval_env=0):
144145
import patsy
145-
y, x = patsy.dmatrices(formula, data)
146+
eval_env = patsy.EvalEnvironment.capture(eval_env, reference=1)
147+
y, x = patsy.dmatrices(formula, data, eval_env=eval_env)
146148
labels = x.design_info.column_names
147149
return cls(np.asarray(x), np.asarray(y)[:, -1], intercept=False,
148150
labels=labels, priors=priors, vars=vars, family=family,

pymc3/tests/test_glm.py

+13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from numpy.testing import assert_equal
33

44
from .helpers import SeededTest
5+
import pymc3
56
from pymc3 import Model, Uniform, Normal, find_MAP, Slice, sample
67
from pymc3 import families, GLM, LinearComponent
78
import pandas as pd
@@ -117,3 +118,15 @@ def test_boolean_y(self):
117118
)
118119
)
119120
assert_equal(model.y.observations, model_bool.y.observations)
121+
122+
def test_glm_formula_from_calling_scope(self):
123+
"""Formula can extract variables from the calling scope."""
124+
z = pd.Series([10, 20, 30])
125+
df = pd.DataFrame({"y": [0, 1, 0], "x": [1.0, 2.0, 3.0]})
126+
GLM.from_formula("y ~ x + z", df, family=pymc3.glm.families.Binomial())
127+
128+
def test_linear_component_formula_from_calling_scope(self):
129+
"""Formula can extract variables from the calling scope."""
130+
z = pd.Series([10, 20, 30])
131+
df = pd.DataFrame({"y": [0, 1, 0], "x": [1.0, 2.0, 3.0]})
132+
LinearComponent.from_formula("y ~ x + z", df)

0 commit comments

Comments
 (0)