Skip to content

Commit e684ede

Browse files
committed
Convert to constant and add comments
1 parent c421634 commit e684ede

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

pymc3/distributions/continuous.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -3873,9 +3873,10 @@ class Interpolated(BoundedContinuous):
38733873
Parameters
38743874
----------
38753875
x_points: array-like
3876-
A monotonically growing list of values
3876+
A monotonically growing list of values. Must be non-symbolic
38773877
pdf_points: array-like
3878-
Probability density function evaluated on lattice ``x_points``
3878+
Probability density function evaluated on lattice ``x_points``. Must
3879+
be non-symbolic
38793880
"""
38803881

38813882
rv_op = interpolated
@@ -3889,9 +3890,9 @@ def dist(cls, x_points, pdf_points, *args, **kwargs):
38893890
cdf_points = interp.antiderivative()(x_points) / Z
38903891
pdf_points = pdf_points / Z
38913892

3892-
x_points = at.as_tensor_variable(floatX(x_points))
3893-
pdf_points = at.as_tensor_variable(floatX(pdf_points))
3894-
cdf_points = at.as_tensor_variable(floatX(cdf_points))
3893+
x_points = at.constant(floatX(x_points))
3894+
pdf_points = at.constant(floatX(pdf_points))
3895+
cdf_points = at.constant(floatX(cdf_points))
38953896

38963897
# lower = at.as_tensor_variable(x_points[0])
38973898
# upper = at.as_tensor_variable(x_points[-1])
@@ -3913,11 +3914,14 @@ def logp(value, x_points, pdf_points, cdf_points):
39133914
-------
39143915
TensorVariable
39153916
"""
3917+
# x_points and pdf_points are expected to be non-symbolic arrays wrapped
3918+
# within a tensor.constant. We use the .data method to retrieve them
39163919
interp = InterpolatedUnivariateSpline(x_points.data, pdf_points.data, k=1, ext="zeros")
3917-
interp_op = SplineWrapper(interp)
3918-
39193920
Z = interp.integral(x_points.data[0], x_points.data[-1])
3920-
Z = at.as_tensor_variable(Z)
3921+
3922+
# interp and Z are converted to symbolic variables here
3923+
interp_op = SplineWrapper(interp)
3924+
Z = at.constant(Z)
39213925

39223926
return at.log(interp_op(value) / Z)
39233927

0 commit comments

Comments
 (0)