Skip to content

Commit 604f9fb

Browse files
ricardoV94twiecki
authored andcommitted
Use observation as value_var of observed rvs
1 parent 9705a8a commit 604f9fb

File tree

3 files changed

+18
-20
lines changed

3 files changed

+18
-20
lines changed

pymc/model.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,7 +1275,7 @@ def make_obs_var(
12751275

12761276
observed_rv_var.tag.observations = nonmissing_data
12771277

1278-
self.create_value_var(observed_rv_var, transform)
1278+
self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
12791279
self.add_random_variable(observed_rv_var, dims)
12801280
self.observed_RVs.append(observed_rv_var)
12811281

@@ -1285,22 +1285,21 @@ def make_obs_var(
12851285
rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var)
12861286
rv_var = Deterministic(name, rv_var, self, dims, auto=True)
12871287

1288-
elif sps.issparse(data):
1289-
data = sparse.basic.as_sparse(data, name=name)
1290-
rv_var.tag.observations = data
1291-
self.create_value_var(rv_var, transform)
1292-
self.add_random_variable(rv_var, dims)
1293-
self.observed_RVs.append(rv_var)
12941288
else:
1295-
data = at.as_tensor_variable(data, name=name)
1289+
if sps.issparse(data):
1290+
data = sparse.basic.as_sparse(data, name=name)
1291+
else:
1292+
data = at.as_tensor_variable(data, name=name)
12961293
rv_var.tag.observations = data
1297-
self.create_value_var(rv_var, transform)
1294+
self.create_value_var(rv_var, transform=None, value_var=data)
12981295
self.add_random_variable(rv_var, dims)
12991296
self.observed_RVs.append(rv_var)
13001297

13011298
return rv_var
13021299

1303-
def create_value_var(self, rv_var: TensorVariable, transform: Any) -> TensorVariable:
1300+
def create_value_var(
1301+
self, rv_var: TensorVariable, transform: Any, value_var: Optional[Variable] = None
1302+
) -> TensorVariable:
13041303
"""Create a ``TensorVariable`` that will be used as the random
13051304
variable's "value" in log-likelihood graphs.
13061305
@@ -1311,13 +1310,13 @@ def create_value_var(self, rv_var: TensorVariable, transform: Any) -> TensorVari
13111310
this branch of the conditional.
13121311
13131312
"""
1314-
value_var = rv_var.type()
1313+
if value_var is None:
1314+
value_var = rv_var.type()
1315+
value_var.name = rv_var.name
13151316

13161317
if aesara.config.compute_test_value != "off":
13171318
value_var.tag.test_value = rv_var.tag.test_value
13181319

1319-
value_var.name = rv_var.name
1320-
13211320
rv_var.tag.value_var = value_var
13221321

13231322
# Make the value variable a transformed value variable,

pymc/tests/test_missing.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919

2020
from numpy import array, ma
2121

22-
from pymc.distributions.continuous import Gamma, Normal, Uniform
23-
from pymc.distributions.transforms import interval
22+
from pymc.distributions import Gamma, Normal, Uniform
2423
from pymc.exceptions import ImputationWarning
2524
from pymc.model import Model
2625
from pymc.sampling import sample, sample_posterior_predictive, sample_prior_predictive
@@ -94,10 +93,10 @@ def test_interval_missing_observations():
9493
with pytest.warns(ImputationWarning):
9594
theta2 = Normal("theta2", mu=theta1, observed=obs2, rng=rng)
9695

97-
assert "theta1_observed_interval__" in model.named_vars
96+
assert "theta1_observed" in model.named_vars
9897
assert "theta1_missing_interval__" in model.named_vars
99-
assert isinstance(
100-
model.rvs_to_values[model.named_vars["theta1_observed"]].tag.transform, interval
98+
assert not hasattr(
99+
model.rvs_to_values[model.named_vars["theta1_observed"]].tag, "transform"
101100
)
102101

103102
prior_trace = sample_prior_predictive(return_inferencedata=False)

pymc/tests/test_smc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,12 +409,12 @@ def test_multiple_simulators(self):
409409
a_val = m.rvs_to_values[a]
410410
sim1_val = m.rvs_to_values[sim1]
411411
logp_sim1 = pm.logpt(sim1, sim1_val)
412-
logp_sim1_fn = aesara.function([sim1_val, a_val], logp_sim1)
412+
logp_sim1_fn = aesara.function([a_val], logp_sim1)
413413

414414
b_val = m.rvs_to_values[b]
415415
sim2_val = m.rvs_to_values[sim2]
416416
logp_sim2 = pm.logpt(sim2, sim2_val)
417-
logp_sim2_fn = aesara.function([sim2_val, b_val], logp_sim2)
417+
logp_sim2_fn = aesara.function([b_val], logp_sim2)
418418

419419
assert any(
420420
node for node in logp_sim1_fn.maker.fgraph.toposort() if isinstance(node.op, SortOp)

0 commit comments

Comments
 (0)