Skip to content

Commit 69815d9

Browse files
ricardoV94twiecki
authored andcommitted
Add model.logp_elemswiset
1 parent 79e346d commit 69815d9

File tree

1 file changed

+62
-3
lines changed

1 file changed

+62
-3
lines changed

pymc/model.py

+62-3
Original file line numberDiff line numberDiff line change
@@ -284,9 +284,8 @@ def logp(self):
284284
"""Compiled log probability density function"""
285285
return self.model.fn(self.logpt)
286286

287-
@property
288-
def logp_elemwise(self):
289-
return self.model.fn(self.logp_elemwiset)
287+
def logp_elemwise(self, vars=None, jacobian=True):
288+
return self.model.fn(self.logp_elemwiset(vars=vars, jacobian=jacobian))
290289

291290
def dlogp(self, vars=None):
292291
"""Compiled log probability density gradient function"""
@@ -728,6 +727,66 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
728727
}
729728
return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)
730729

730+
def logp_elemwiset(
731+
self,
732+
vars: Optional[Union[Variable, List[Variable]]] = None,
733+
jacobian: bool = True,
734+
) -> List[Variable]:
735+
"""Elemwise log-probability of the model.
736+
737+
Parameters
738+
----------
739+
vars: list of random variables or potential terms, optional
740+
Compute the gradient with respect to those variables. If None, use all
741+
free and observed random variables, as well as potential terms in model.
742+
jacobian
743+
Whether to include jacobian terms in logprob graph. Defaults to True.
744+
745+
Returns
746+
-------
747+
Elemwise logp terms for ecah requested variable, in the same order of input.
748+
"""
749+
if vars is None:
750+
vars = self.free_RVs + self.observed_RVs + self.potentials
751+
elif not isinstance(vars, (list, tuple)):
752+
vars = [vars]
753+
754+
# We need to separate random variables from potential terms, and remember their
755+
# original order so that we can merge them together in the same order at the end
756+
rv_values = {}
757+
potentials = []
758+
rv_order, potential_order = [], []
759+
for i, var in enumerate(vars):
760+
value_var = self.rvs_to_values.get(var)
761+
if value_var is not None:
762+
rv_values[var] = value_var
763+
rv_order.append(i)
764+
else:
765+
if var in self.potentials:
766+
potentials.append(var)
767+
potential_order.append(i)
768+
else:
769+
raise ValueError(
770+
f"Requested variable {var} not found among the model variables"
771+
)
772+
773+
rv_logps = []
774+
if rv_values:
775+
rv_logps = logpt(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
776+
if not isinstance(rv_logps, list):
777+
rv_logps = [rv_logps]
778+
779+
# Replace random variables by their value variables in potential terms
780+
potential_logps = []
781+
if potentials:
782+
potential_logps, _ = rvs_to_value_vars(potentials, apply_transforms=True)
783+
784+
logp_elemwise = [None] * len(vars)
785+
for logp_order, logp in zip((rv_order + potential_order), (rv_logps + potential_logps)):
786+
logp_elemwise[logp_order] = logp
787+
788+
return logp_elemwise
789+
731790
@property
732791
def logpt(self):
733792
"""Aesara scalar of log-probability of the model"""

0 commit comments

Comments
 (0)