Skip to content

Commit 73694ef

Browse files
a-rodinferrine
authored andcommitted
Raise NotImplementedError for SplineWrapper gradient operation (#2211)
* Make spacing more compliant to PEP8 * Raise NotImplementedError from SplineWrapper grad operation Fixes #2209 * Instantiate SplineWrapper recursively to simplify the architecture * Create spline derivatives lazily * Move grad_op to a separate property * Add tests for SplineWrapper * Fix style issues
1 parent 99744e3 commit 73694ef

File tree

3 files changed

+40
-17
lines changed

3 files changed

+40
-17
lines changed

pymc3/distributions/continuous.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from . import transforms
1818
from pymc3.util import get_variable_name
1919

20-
from .dist_math import bound, logpow, gammaln, betaln, std_cdf, i0, i1, alltrue_elemwise, DifferentiableSplineWrapper
20+
from .dist_math import (
21+
bound, logpow, gammaln, betaln, std_cdf, i0,
22+
i1, alltrue_elemwise, SplineWrapper
23+
)
2124
from .distribution import Continuous, draw_values, generate_samples, Bound
2225

2326
__all__ = ['Uniform', 'Flat', 'Normal', 'Beta', 'Exponential', 'Laplace',
@@ -1625,7 +1628,7 @@ def __init__(self, x_points, pdf_points, transform='interval',
16251628
Z = interp.integral(x_points[0], x_points[-1])
16261629

16271630
self.Z = tt.as_tensor_variable(Z)
1628-
self.interp_op = DifferentiableSplineWrapper(interp)
1631+
self.interp_op = SplineWrapper(interp)
16291632
self.x_points = x_points
16301633
self.pdf_points = pdf_points / Z
16311634
self.cdf_points = interp.antiderivative()(x_points) / Z

pymc3/distributions/dist_math.py

+16-13
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,8 @@ def conjugate_solve_triangular(outer, inner):
369369
grad = tt.triu(s + s.T) - tt.diag(tt.diagonal(s))
370370
return [tt.switch(ok, grad, floatX(np.nan))]
371371

372-
class SplineWrapper (theano.Op):
372+
373+
class SplineWrapper(theano.Op):
373374
"""
374375
Creates a theano operation from scipy.interpolate.UnivariateSpline
375376
"""
@@ -381,22 +382,24 @@ class SplineWrapper (theano.Op):
381382
def __init__(self, spline):
382383
self.spline = spline
383384

385+
@property
386+
def grad_op(self):
387+
if not hasattr(self, '_grad_op'):
388+
try:
389+
self._grad_op = SplineWrapper(self.spline.derivative())
390+
except ValueError:
391+
self._grad_op = None
392+
393+
if self._grad_op is None:
394+
raise NotImplementedError('Spline of order 0 is not differentiable')
395+
return self._grad_op
396+
384397
def perform(self, node, inputs, output_storage):
385398
x, = inputs
386399
output_storage[0][0] = np.asarray(self.spline(x))
387400

388-
class DifferentiableSplineWrapper (SplineWrapper):
389-
"""
390-
Creates a theano operation with defined gradient from
391-
scipy.interpolate.UnivariateSpline
392-
"""
393-
394-
def __init__(self, spline):
395-
super(DifferentiableSplineWrapper, self).__init__(spline)
396-
self.spline_grad = SplineWrapper(spline.derivative())
397-
self.__props__ += ('spline_grad',)
398-
399401
def grad(self, inputs, grads):
400402
x, = inputs
401403
x_grad, = grads
402-
return [x_grad * self.spline_grad(x)]
404+
405+
return [x_grad * self.grad_op(x)]

pymc3/tests/test_dist_math.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
import theano
55
import theano.tests.unittest_tools as utt
66
import pymc3 as pm
7-
from scipy import stats
7+
from scipy import stats, interpolate
88
import pytest
99

1010
from ..theanof import floatX
1111
from ..distributions import Discrete
1212
from ..distributions.dist_math import (
13-
bound, factln, alltrue_scalar, MvNormalLogp)
13+
bound, factln, alltrue_scalar, MvNormalLogp, SplineWrapper)
1414

1515

1616
def test_bound():
@@ -176,3 +176,20 @@ def test_hessian(self):
176176
logp = MvNormalLogp()(cov, delta)
177177
g_cov, g_delta = tt.grad(logp, [cov, delta])
178178
tt.grad(g_delta.sum() + g_cov.sum(), [delta, cov])
179+
180+
181+
class TestSplineWrapper(object):
182+
def test_grad(self):
183+
x = np.linspace(0, 1, 100)
184+
y = x * x
185+
spline = SplineWrapper(interpolate.InterpolatedUnivariateSpline(x, y, k=1))
186+
utt.verify_grad(spline, [0.5])
187+
188+
def test_hessian(self):
189+
x = np.linspace(0, 1, 100)
190+
y = x * x
191+
spline = SplineWrapper(interpolate.InterpolatedUnivariateSpline(x, y, k=1))
192+
x_var = tt.dscalar('x')
193+
g_x, = tt.grad(spline(x_var), [x_var])
194+
with pytest.raises(NotImplementedError):
195+
tt.grad(g_x, [x_var])

0 commit comments

Comments
 (0)