15
15
16
16
from collections .abc import Mapping
17
17
from functools import singledispatch
18
- from typing import Dict , Optional , Union
18
+ from typing import Dict , List , Optional , Union
19
19
20
20
import aesara .tensor as at
21
21
import numpy as np
@@ -119,15 +119,15 @@ def _get_scaling(total_size, shape, ndim):
119
119
120
120
121
121
def logpt (
122
- var : TensorVariable ,
122
+ var : Union [ TensorVariable , List [ TensorVariable ]] ,
123
123
rv_values : Optional [Union [TensorVariable , Dict [TensorVariable , TensorVariable ]]] = None ,
124
124
* ,
125
125
jacobian : bool = True ,
126
126
scaling : bool = True ,
127
127
transformed : bool = True ,
128
128
sum : bool = True ,
129
129
** kwargs ,
130
- ) -> TensorVariable :
130
+ ) -> Union [ TensorVariable , List [ TensorVariable ]] :
131
131
"""Create a measure-space (i.e. log-likelihood) graph for a random variable
132
132
or a list of random variables at a given point.
133
133
@@ -154,7 +154,7 @@ def logpt(
154
154
transformed
155
155
Apply transforms.
156
156
sum
157
- Sum the log-likelihood.
157
+ Sum the log-likelihood or return each term as a separate list item .
158
158
159
159
"""
160
160
# TODO: In future when we drop support for tag.value_var most of the following
@@ -241,7 +241,13 @@ def logpt(
241
241
if sum :
242
242
logp_var = at .sum ([at .sum (factor ) for factor in logp_var_dict .values ()])
243
243
else :
244
- logp_var = at .add (* logp_var_dict .values ())
244
+ logp_var = list (logp_var_dict .values ())
245
+ # TODO: deprecate special behavior when only one variable is requested and
246
+ # always return a list. This is here for backwards compatibility as logpt
247
+ # started as a replacement to factor.logpt, but it should now be considered an
248
+ # internal function reached only via model.logp* methods.
249
+ if len (logp_var ) == 1 :
250
+ logp_var = logp_var [0 ]
245
251
246
252
return logp_var
247
253
0 commit comments