Skip to content

Commit 79e346d

Browse files
ricardoV94twiecki
authored andcommitted
Return separate logp terms when logpt is called with sum==False
1 parent a44515c commit 79e346d

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

pymc/distributions/logprob.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from collections.abc import Mapping
1717
from functools import singledispatch
18-
from typing import Dict, Optional, Union
18+
from typing import Dict, List, Optional, Union
1919

2020
import aesara.tensor as at
2121
import numpy as np
@@ -119,15 +119,15 @@ def _get_scaling(total_size, shape, ndim):
119119

120120

121121
def logpt(
122-
var: TensorVariable,
122+
var: Union[TensorVariable, List[TensorVariable]],
123123
rv_values: Optional[Union[TensorVariable, Dict[TensorVariable, TensorVariable]]] = None,
124124
*,
125125
jacobian: bool = True,
126126
scaling: bool = True,
127127
transformed: bool = True,
128128
sum: bool = True,
129129
**kwargs,
130-
) -> TensorVariable:
130+
) -> Union[TensorVariable, List[TensorVariable]]:
131131
"""Create a measure-space (i.e. log-likelihood) graph for a random variable
132132
or a list of random variables at a given point.
133133
@@ -154,7 +154,7 @@ def logpt(
154154
transformed
155155
Apply transforms.
156156
sum
157-
Sum the log-likelihood.
157+
Sum the log-likelihood or return each term as a separate list item.
158158
159159
"""
160160
# TODO: In future when we drop support for tag.value_var most of the following
@@ -241,7 +241,13 @@ def logpt(
241241
if sum:
242242
logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()])
243243
else:
244-
logp_var = at.add(*logp_var_dict.values())
244+
logp_var = list(logp_var_dict.values())
245+
# TODO: deprecate special behavior when only one variable is requested and
246+
# always return a list. This is here for backwards compatibility as logpt
247+
# started as a replacement to factor.logpt, but it should now be considered an
248+
# internal function reached only via model.logp* methods.
249+
if len(logp_var) == 1:
250+
logp_var = logp_var[0]
245251

246252
return logp_var
247253

pymc/tests/test_logprob.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ def test_logpt_subtensor():
144144
I_value_var = I_rv.type()
145145
I_value_var.name = "I_value"
146146

147-
A_idx_logp = logpt(A_idx, {A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False)
147+
A_idx_logps = logpt(A_idx, {A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False)
148+
A_idx_logp = at.add(*A_idx_logps)
148149

149150
logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp)
150151

0 commit comments

Comments
 (0)