Skip to content

Commit 269da5c

Browse files
committed
Allow opting out of model nesting
1 parent b77747d commit 269da5c

File tree

8 files changed

+55
-42
lines changed

8 files changed

+55
-42
lines changed

pymc/model/core.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Literal,
2626
Optional,
2727
TypeVar,
28+
Union,
2829
cast,
2930
overload,
3031
)
@@ -441,7 +442,7 @@ class Model(WithMemoization, metaclass=ContextMeta):
441442
442443
coords = {
443444
"feature", ["A", "B", "C"],
444-
"trial", [1, 2, 3, 4, 5],
445+
"trial", [1, 2, 3, 4, 5],
445446
}
446447
447448
with pm.Model(coords=coords) as model:
@@ -476,6 +477,11 @@ class Model(WithMemoization, metaclass=ContextMeta):
476477
# Variable will belong to root and second
477478
z = pm.Normal("z", mu=y) # Variable wil be named "root::second::z"
478479
480+
# Set None for standalone model
481+
with pm.Model(name="third", model=None) as third:
482+
# Variable will belong to third only
483+
w = pm.Normal("w") # Variable wil be named "third::w"
484+
479485
480486
Set `check_bounds` to False for models with only continuous variables and default transformers
481487
PyMC will remove the bounds check from the model logp which can speed up sampling
@@ -497,13 +503,13 @@ def __enter__(self: Self) -> Self: ...
497503

498504
def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: ...
499505

500-
def __new__(cls, *args, **kwargs):
506+
def __new__(cls, *args, model: Union[Literal[UNSET], None, "Model"] = UNSET, **kwargs):
501507
# resolves the parent instance
502508
instance = super().__new__(cls)
503-
if kwargs.get("model") is not None:
504-
instance._parent = kwargs.get("model")
505-
else:
509+
if model is UNSET:
506510
instance._parent = cls.get_context(error_if_none=False)
511+
else:
512+
instance._parent = model
507513
return instance
508514

509515
@staticmethod
@@ -519,9 +525,9 @@ def __init__(
519525
check_bounds=True,
520526
*,
521527
coords_mutable=None,
522-
model=None,
528+
model: Union[Literal[UNSET], None, "Model"] = UNSET,
523529
):
524-
del model # used in __new__
530+
del model # used in __new__ to define the parent of this model
525531
self.name = self._validate_name(name)
526532
self.check_bounds = check_bounds
527533

pymc/model/fgraph.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,7 @@ def first_non_model_var(var):
299299
else:
300300
return var
301301

302-
model = Model()
303-
if model.parent is not None:
304-
raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context")
302+
model = Model(model=None) # Do not inherit from any model in the context manager
305303

306304
_coords = getattr(fgraph, "_coords", {})
307305
_dim_lengths = getattr(fgraph, "_dim_lengths", {})

pymc/model/transform/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pytensor import Variable
1717
from pytensor.graph import ancestors
1818

19-
from pymc import Model
19+
from pymc.model.core import Model
2020
from pymc.model.fgraph import (
2121
ModelObservedRV,
2222
ModelVar,

pymc/model/transform/conditioning.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from pytensor.graph import ancestors
2020
from pytensor.tensor import TensorVariable
2121

22-
from pymc import Model
2322
from pymc.logprob.transforms import Transform
2423
from pymc.logprob.utils import rvs_in_graph
24+
from pymc.model.core import Model
2525
from pymc.model.fgraph import (
2626
ModelDeterministic,
2727
ModelFreeRV,

pymc/sampling/deterministic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def compute_deterministics(
8383
model = modelcontext(model)
8484

8585
if var_names is None:
86-
deterministics = model.deterministics
86+
deterministics = list(model.deterministics)
8787
var_names = [det.name for det in deterministics]
8888
else:
8989
deterministics = [model[var_name] for var_name in var_names]

pymc/stats/log_density.py

+15-24
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
__all__ = ("compute_log_likelihood", "compute_log_prior")
2727

28+
from pymc.model.transform.conditioning import remove_value_transforms
29+
2830

2931
def compute_log_likelihood(
3032
idata: InferenceData,
@@ -126,46 +128,35 @@ def compute_log_density(
126128
if kind not in ("likelihood", "prior"):
127129
raise ValueError("kind must be either 'likelihood' or 'prior'")
128130

131+
# We need to disable transforms, because the InferenceData only keeps the untransformed values
132+
umodel = remove_value_transforms(model)
133+
129134
if kind == "likelihood":
130-
target_rvs = model.observed_RVs
135+
target_rvs = list(umodel.observed_RVs)
131136
target_str = "observed_RVs"
132137
else:
133-
target_rvs = model.free_RVs
138+
target_rvs = list(umodel.free_RVs)
134139
target_str = "free_RVs"
135140

136141
if var_names is None:
137142
vars = target_rvs
138143
var_names = tuple(rv.name for rv in vars)
139144
else:
140-
vars = [model.named_vars[name] for name in var_names]
145+
vars = [umodel.named_vars[name] for name in var_names]
141146
if not set(vars).issubset(target_rvs):
142147
raise ValueError(f"var_names must refer to {target_str} in the model. Got: {var_names}")
143148

144-
# We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values
145-
try:
146-
original_rvs_to_values = model.rvs_to_values
147-
original_rvs_to_transforms = model.rvs_to_transforms
148-
149-
model.rvs_to_values = {
150-
rv: rv.clone() if rv not in model.observed_RVs else value
151-
for rv, value in model.rvs_to_values.items()
152-
}
153-
model.rvs_to_transforms = {rv: None for rv in model.basic_RVs}
154-
155-
elemwise_logdens_fn = model.compile_fn(
156-
inputs=model.value_vars,
157-
outs=model.logp(vars=vars, sum=False),
158-
on_unused_input="ignore",
159-
)
160-
finally:
161-
model.rvs_to_values = original_rvs_to_values
162-
model.rvs_to_transforms = original_rvs_to_transforms
149+
elemwise_logdens_fn = umodel.compile_fn(
150+
inputs=umodel.value_vars,
151+
outs=umodel.logp(vars=vars, sum=False),
152+
on_unused_input="ignore",
153+
)
163154

164-
coords, dims = coords_and_dims_for_inferencedata(model)
155+
coords, dims = coords_and_dims_for_inferencedata(umodel)
165156

166157
logdens_dataset = apply_function_over_dataset(
167158
elemwise_logdens_fn,
168-
posterior[[rv.name for rv in model.free_RVs]],
159+
posterior[[rv.name for rv in umodel.free_RVs]],
169160
output_var_names=var_names,
170161
sample_dims=sample_dims,
171162
dims=dims,

tests/model/test_core.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,20 @@ def test_docstring_example(self):
143143
# Variable will belong to root and second
144144
z = pm.Normal("z", mu=y) # Variable wil be named "root::second::z"
145145

146+
# Set None for standalone model
147+
with pm.Model(name="third", model=None) as third:
148+
# Variable will belong to third only
149+
w = pm.Normal("w") # Variable wil be named "third::w"
150+
146151
assert x.name == "root::x"
147152
assert y.name == "root::first::y"
148153
assert z.name == "root::second::z"
154+
assert w.name == "third::w"
149155

150156
assert set(root.basic_RVs) == {x, y, z}
151157
assert set(first.basic_RVs) == {y}
152158
assert set(second.basic_RVs) == {z}
159+
assert set(third.basic_RVs) == {w}
153160

154161

155162
class TestNested:
@@ -1106,11 +1113,17 @@ def test_model_parent_set_programmatically():
11061113
y = pm.Normal("y")
11071114

11081115
with model:
1116+
# Default inherits from model
1117+
with pm.Model():
1118+
z_in = pm.Normal("z_in")
1119+
1120+
# Explict None opts out of model context
11091121
with pm.Model(model=None):
1110-
z = pm.Normal("z")
1122+
z_out = pm.Normal("z_out")
11111123

11121124
assert "y" in model.named_vars
1113-
assert "z" in model.named_vars
1125+
assert "z_in" in model.named_vars
1126+
assert "z_out" not in model.named_vars
11141127

11151128

11161129
class TestModelContext:

tests/model/test_fgraph.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,15 @@ def test_context_error():
267267
with pm.Model() as m:
268268
x = pm.Normal("x")
269269

270-
fg = fgraph_from_model(m)
270+
fg, _ = fgraph_from_model(m)
271271

272-
with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"):
273-
model_from_fgraph(fg)
272+
new_m = model_from_fgraph(fg)
273+
new_x = new_m["x"]
274+
275+
assert new_m.parent is None
276+
assert x != new_x
277+
assert m.named_vars == {"x": x}
278+
assert new_m.named_vars == {"x": new_x}
274279

275280

276281
def test_sub_model_error():

0 commit comments

Comments
 (0)