Skip to content

Commit 4500708

Browse files
authored
Change default_*_config from @Property to @staticmethod (#235)
* Change default_*_config from @Property to @staticmethod Fixes #225. Since the method changed from a property to a method, it was renamed from a noun to a verb. Both `default_model_config` and `default_sampler_config` were changed. * Fix to use staticmethod instead of classmethod
1 parent a9e0be5 commit 4500708

File tree

3 files changed

+23
-21
lines changed

3 files changed

+23
-21
lines changed

pymc_experimental/linearmodel.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ def __init__(self, model_config: Dict = None, sampler_config: Dict = None, nsamp
2020
_model_type = "LinearModel"
2121
version = "0.1"
2222

23-
@property
24-
def default_model_config(self):
23+
@staticmethod
24+
def get_default_model_config():
2525
return {
2626
"intercept": {"loc": 0, "scale": 10},
2727
"slope": {"loc": 0, "scale": 10},
2828
"obs_error": 2,
2929
}
3030

31-
@property
32-
def default_sampler_config(self):
31+
@staticmethod
32+
def get_default_sampler_config():
3333
return {
3434
"draws": 1_000,
3535
"tune": 1_000,

pymc_experimental/model_builder.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,11 @@ def __init__(
7171
>>> ...
7272
>>> model = MyModel(model_config, sampler_config)
7373
"""
74-
sampler_config = self.default_sampler_config if sampler_config is None else sampler_config
74+
sampler_config = (
75+
self.get_default_sampler_config() if sampler_config is None else sampler_config
76+
)
7577
self.sampler_config = sampler_config
76-
model_config = self.default_model_config if model_config is None else model_config
78+
model_config = self.get_default_model_config() if model_config is None else model_config
7779

7880
self.model_config = model_config # parameters for priors etc.
7981
self.model = None # Set by build_model
@@ -133,17 +135,17 @@ def output_var(self):
133135
"""
134136
raise NotImplementedError
135137

136-
@property
138+
@staticmethod
137139
@abstractmethod
138-
def default_model_config(self) -> Dict:
140+
def get_default_model_config() -> Dict:
139141
"""
140142
Returns a class default config dict for model builder if no model_config is provided on class initialization
141143
Useful for understanding structure of required model_config to allow its customization by users
142144
Examples
143145
--------
144-
>>> @classmethod
145-
>>> def default_model_config(self):
146-
>>> Return {
146+
>>> @staticmethod
147+
>>> def default_model_config():
148+
>>> return {
147149
>>> 'a' : {
148150
>>> 'loc': 7,
149151
>>> 'scale' : 3
@@ -162,17 +164,17 @@ def default_model_config(self) -> Dict:
162164
"""
163165
raise NotImplementedError
164166

165-
@property
167+
@staticmethod
166168
@abstractmethod
167-
def default_sampler_config(self) -> Dict:
169+
def get_default_sampler_config(self) -> Dict:
168170
"""
169171
Returns a class default sampler dict for model builder if no sampler_config is provided on class initialization
170172
Useful for understanding structure of required sampler_config to allow its customization by users
171173
Examples
172174
--------
173-
>>> @classmethod
174-
>>> def default_sampler_config(self):
175-
>>> Return {
175+
>>> @staticmethod
176+
>>> def default_sampler_config():
177+
>>> return {
176178
>>> 'draws': 1_000,
177179
>>> 'tune': 1_000,
178180
>>> 'chains': 1,

pymc_experimental/tests/test_model_builder.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
7474
self.generate_and_preprocess_model_data(X, y)
7575
with pm.Model(coords=coords) as self.model:
7676
if model_config is None:
77-
model_config = self.default_model_config
77+
model_config = self.model_config
7878
x = pm.MutableData("x", self.X["input"].values)
7979
y_data = pm.MutableData("y_data", self.y)
8080

@@ -114,8 +114,8 @@ def generate_and_preprocess_model_data(self, X: pd.DataFrame, y: pd.Series):
114114
self.X = X
115115
self.y = y
116116

117-
@property
118-
def default_model_config(self) -> Dict:
117+
@staticmethod
118+
def get_default_model_config() -> Dict:
119119
return {
120120
"a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
121121
"b": {"loc": 0, "scale": 10},
@@ -128,8 +128,8 @@ def _generate_and_preprocess_model_data(
128128
self.X = X
129129
self.y = y
130130

131-
@property
132-
def default_sampler_config(self) -> Dict:
131+
@staticmethod
132+
def get_default_sampler_config() -> Dict:
133133
return {
134134
"draws": 1_000,
135135
"tune": 1_000,

0 commit comments

Comments
 (0)