Skip to content

Commit 79dcea4

Browse files
authored
Merge branch 'master' into matmul
2 parents b370d11 + e3b667c commit 79dcea4

17 files changed

+1470
-295
lines changed

.travis.yml

-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ env:
3535

3636
script:
3737
- . ./scripts/test.sh $TESTCMD
38-
- . ./scripts/confirm_mpl_optional.sh
3938

4039
after_success:
4140
- coveralls

RELEASE-NOTES.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
## PyMC3 3.8 (on deck)
44

55
### New features
6-
6+
- Add capabilities to do inference on parameters in a differential equation with `DifferentialEquation`. See [#3590](https://github.com/pymc-devs/pymc3/pull/3590).
77
- Distinguish between `Data` and `Deterministic` variables when graphing models with graphviz. PR [#3491](https://github.com/pymc-devs/pymc3/pull/3491).
88
- Sequential Monte Carlo - Approximate Bayesian Computation step method is now available. The implementation is in an experimental stage and will be further improved.
99
- Added `Matern12` covariance function for Gaussian processes. This is the Matern kernel with nu=1/2.
1010
- Progressbar reports number of divergences in real time, when available [#3547](https://github.com/pymc-devs/pymc3/pull/3547).
1111
- Sampling from variational approximation now allows for alternative trace backends [#3550].
12-
- Infix `@` operator now works with random variables and deterministics [#3578](https://github.com/pymc-devs/pymc3/pull/3578).
12+
- Infix `@` operator now works with random variables and deterministics [#3619](https://github.com/pymc-devs/pymc3/pull/3619).
13+
- [ArviZ](https://arviz-devs.github.io/arviz/) is now a requirement, and handles plotting, diagnostics, and statistical checks.
1314

1415
### Maintenance
1516
- Moved math operations out of `Rice`, `TruncatedNormal`, `Triangular` and `ZeroInflatedNegativeBinomial` `random` methods. Math operations on values returned by `draw_values` might not broadcast well, and all the `size` aware broadcasting is left to `generate_samples`. Fixes [#3481](https://github.com/pymc-devs/pymc3/issues/3481) and [#3508](https://github.com/pymc-devs/pymc3/issues/3508)

docs/source/notebooks/ODE_API_parameter_estimation.ipynb

+570
Large diffs are not rendered by default.

docs/source/notebooks/ODE_parameter_estimation.ipynb

+99-197
Large diffs are not rendered by default.

docs/source/notebooks/table_of_contents_examples.js

+2-1
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,6 @@ Gallery.contents = {
5353
"normalizing_flows_overview": "Variational Inference",
5454
"gaussian-mixture-model-advi": "Variational Inference",
5555
"GLM-hierarchical-advi-minibatch": "Variational Inference",
56-
"ODE_parameter_estimation": "Inference in ODE models"
56+
"ODE_parameter_estimation": "Inference in ODE models",
57+
"ODE_API_parameter_estimation": "Inference in ODE models with DifferentialEquation"
5758
}

pymc3/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .math import logaddexp, logsumexp, logit, invlogit, expand_packed_triangular, probit, invprobit
99
from .model import *
1010
from .model_graph import model_to_graphviz
11+
from . import ode
1112
from .stats import *
1213
from .sampling import *
1314
from .step_methods import *

pymc3/backends/base.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
import itertools as itl
77
import logging
8+
from typing import List
89

910
import numpy as np
1011
import warnings
@@ -92,7 +93,8 @@ def _set_sampler_vars(self, sampler_vars):
9293

9394
self.sampler_vars = sampler_vars
9495

95-
def setup(self, draws, chain, sampler_vars=None):
96+
# pylint: disable=unused-argument
97+
def setup(self, draws, chain, sampler_vars=None) -> None:
9698
"""Perform chain-specific setup.
9799
98100
Parameters
@@ -542,7 +544,7 @@ def points(self, chains=None):
542544
return itl.chain.from_iterable(self._straces[chain] for chain in chains)
543545

544546

545-
def merge_traces(mtraces):
547+
def merge_traces(mtraces: List[MultiTrace]) -> MultiTrace:
546548
"""Merge MultiTrace objects.
547549
548550
Parameters
@@ -552,17 +554,26 @@ def merge_traces(mtraces):
552554
553555
Raises
554556
------
555-
A ValueError is raised if any traces have overlapping chain numbers.
557+
A ValueError is raised if any traces have overlapping chain numbers,
558+
or if chains are of different lengths.
556559
557560
Returns
558561
-------
559562
A MultiTrace instance with merged chains
560563
"""
564+
if len(mtraces) == 0:
565+
raise ValueError("Cannot merge an empty set of traces.")
561566
base_mtrace = mtraces[0]
567+
chain_len = len(base_mtrace)
568+
# check base trace
569+
if any(len(st) != chain_len for _, st in base_mtrace._straces.items()): # pylint: disable=line-too-long
570+
raise ValueError("Chains are of different lengths.")
562571
for new_mtrace in mtraces[1:]:
563572
for new_chain, strace in new_mtrace._straces.items():
564573
if new_chain in base_mtrace._straces:
565574
raise ValueError("Chains are not unique.")
575+
if len(strace) != chain_len:
576+
raise ValueError("Chains are of different lengths.")
566577
base_mtrace._straces[new_chain] = strace
567578
base_mtrace._report = merge_reports([trace.report for trace in mtraces])
568579
return base_mtrace

pymc3/ode/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from . import utils
2+
from .ode import DifferentialEquation

pymc3/ode/ode.py

+183
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import numpy as np
2+
import scipy
3+
import theano
4+
import theano.tensor as tt
5+
from ..ode.utils import augment_system, ODEGradop
6+
7+
8+
class DifferentialEquation(theano.Op):
9+
"""
10+
Specify an ordinary differential equation
11+
12+
.. math::
13+
\dfrac{dy}{dt} = f(y,t,p) \quad y(t_0) = y_0
14+
15+
Parameters
16+
----------
17+
18+
func : callable
19+
Function specifying the differential equation
20+
t0 : float
21+
Time corresponding to the initial condition
22+
times : array
23+
Array of times at which to evaluate the solution of the differential equation.
24+
n_states : int
25+
Dimension of the differential equation. For scalar differential equations, n_states=1.
26+
For vector valued differential equations, n_states = number of differential equations in the system.
27+
n_odeparams : int
28+
Number of parameters in the differential equation.
29+
30+
.. code-block:: python
31+
32+
def odefunc(y, t, p):
33+
#Logistic differential equation
34+
return p[0] * y[0] * (1 - y[0])
35+
36+
times = np.arange(0.5, 5, 0.5)
37+
38+
ode_model = DifferentialEquation(func=odefunc, t0=0, times=times, n_states=1, n_odeparams=1)
39+
"""
40+
41+
__props__ = ("func", "t0", "times", "n_states", "n_odeparams")
42+
43+
def __init__(self, func, times, n_states, n_odeparams, t0=0):
44+
if not callable(func):
45+
raise ValueError("Argument func must be callable.")
46+
if n_states < 1:
47+
raise ValueError("Argument n_states must be at least 1.")
48+
if n_odeparams <= 0:
49+
raise ValueError("Argument n_odeparams must be positive.")
50+
51+
# Public
52+
self.func = func
53+
self.t0 = t0
54+
self.times = tuple(times)
55+
self.n_states = n_states
56+
self.n_odeparams = n_odeparams
57+
58+
# Private
59+
self._n = n_states
60+
self._m = n_odeparams + n_states
61+
62+
self._augmented_times = np.insert(times, 0, t0)
63+
self._augmented_func = augment_system(func, self._n, self._m)
64+
self._sens_ic = self._make_sens_ic()
65+
66+
self._cached_y = None
67+
self._cached_sens = None
68+
self._cached_parameters = None
69+
70+
self._grad_op = ODEGradop(self._numpy_vsp)
71+
72+
def _make_sens_ic(self):
73+
"""
74+
The sensitivity matrix will always have consistent form.
75+
If the first n_odeparams entries of the parameters vector in the simulate call
76+
correspond to ode parameters, then the first n_odeparams columns in
77+
the sensitivity matrix will be 0
78+
79+
If the last n_states entries of the parameters vector in the simulate call
80+
correspond to initial conditions of the system,
81+
then the last n_states columns of the sensitivity matrix should form
82+
an identity matrix
83+
"""
84+
85+
# Initialize the sensitivity matrix to be 0 everywhere
86+
sens_matrix = np.zeros((self._n, self._m))
87+
88+
# Slip in the identity matrix in the appropriate place
89+
sens_matrix[:, -self.n_states :] = np.eye(self.n_states)
90+
91+
# We need the sensitivity matrix to be a vector (see augmented_function)
92+
# Ravel and return
93+
dydp = sens_matrix.ravel()
94+
95+
return dydp
96+
97+
def _system(self, Y, t, p):
98+
"""This is the function that will be passed to odeint. Solves both ODE and sensitivities
99+
100+
"""
101+
102+
dydt, ddt_dydp = self._augmented_func(Y[: self._n], t, p, Y[self._n :])
103+
derivatives = np.concatenate([dydt, ddt_dydp])
104+
return derivatives
105+
106+
def _simulate(self, parameters):
107+
# Initial condition comprised of state initial conditions and raveled
108+
# sensitivity matrix
109+
y0 = np.concatenate([parameters[self.n_odeparams :], self._sens_ic])
110+
111+
# perform the integration
112+
sol = scipy.integrate.odeint(
113+
func=self._system, y0=y0, t=self._augmented_times, args=(parameters,)
114+
)
115+
# The solution
116+
y = sol[1:, : self.n_states]
117+
118+
# The sensitivities, reshaped to be a sequence of matrices
119+
sens = sol[1:, self.n_states :].reshape(len(self.times), self._n, self._m)
120+
121+
return y, sens
122+
123+
def _cached_simulate(self, parameters):
124+
if np.array_equal(np.array(parameters), self._cached_parameters):
125+
126+
return self._cached_y, self._cached_sens
127+
128+
return self._simulate(np.array(parameters))
129+
130+
def _state(self, parameters):
131+
y, sens = self._cached_simulate(np.array(parameters))
132+
self._cached_y, self._cached_sens, self._cached_parameters = y, sens, parameters
133+
return y.ravel()
134+
135+
def _numpy_vsp(self, parameters, g):
136+
_, sens = self._cached_simulate(np.array(parameters))
137+
138+
# Each element of sens is an nxm sensitivity matrix
139+
# There is one sensitivity matrix per time step, making sens a (len(times), n_states, len(parameters))
140+
# dimensional array. Reshaping the sens array in this way is like stacking each of the elements of sens on top
141+
# of one another.
142+
numpy_sens = sens.reshape((self.n_states * len(self.times), len(parameters)))
143+
# The dot product here is equivalent to np.einsum('ijk,jk', sens, g)
144+
# if sens was not reshaped and if g had the same shape as yobs
145+
return numpy_sens.T.dot(g)
146+
147+
def make_node(self, odeparams, y0):
148+
if len(odeparams) != self.n_odeparams:
149+
raise ValueError(
150+
"odeparams has too many or too few parameters. Expected {a} parameter(s) but got {b}".format(
151+
a=self.n_odeparams, b=len(odeparams)
152+
)
153+
)
154+
if len(y0) != self.n_states:
155+
raise ValueError(
156+
"y0 has too many or too few parameters. Expected {a} parameter(s) but got {b}".format(
157+
a=self.n_states, b=len(y0)
158+
)
159+
)
160+
161+
if np.ndim(odeparams) > 1:
162+
odeparams = np.ravel(odeparams)
163+
if np.ndim(y0) > 1:
164+
y0 = np.ravel(y0)
165+
166+
odeparams = tt.as_tensor_variable(odeparams)
167+
y0 = tt.as_tensor_variable(y0)
168+
parameters = tt.concatenate([odeparams, y0])
169+
return theano.Apply(self, [parameters], [parameters.type()])
170+
171+
def perform(self, node, inputs_storage, output_storage):
172+
parameters = inputs_storage[0]
173+
out = output_storage[0]
174+
# get the numerical solution of ODE states
175+
out[0] = self._state(parameters)
176+
177+
def grad(self, inputs, output_grads):
178+
x = inputs[0]
179+
g = output_grads[0]
180+
# pass the VSP when asked for gradient
181+
grad_op_apply = self._grad_op(x, g)
182+
183+
return [grad_op_apply]

pymc3/ode/utils.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import numpy as np
2+
import theano
3+
import theano.tensor as tt
4+
5+
6+
def augment_system(ode_func, n, m):
7+
"""
8+
Function to create augmented system.
9+
10+
Take a function which specifies a set of differential equations and return
11+
a compiled function which allows for computation of gradients of the
12+
differential equation's solution with respect to the parameters.
13+
14+
Parameters
15+
----------
16+
ode_func : function
17+
Differential equation. Returns array-like.
18+
n : int
19+
Number of rows of the sensitivity matrix.
20+
m : int
21+
Number of columns of the sensitivity matrix.
22+
23+
Returns
24+
-------
25+
system : function
26+
Augmented system of differential equations.
27+
"""
28+
29+
# Present state of the system
30+
t_y = tt.vector("y", dtype=theano.config.floatX)
31+
t_y.tag.test_value = np.zeros((n,))
32+
# Parameter(s). Should be vector to allow for generalization to multiparameter
33+
# systems of ODEs. Is m dimensional because it includes all ode parameters as well as initial conditions
34+
t_p = tt.vector("p", dtype=theano.config.floatX)
35+
t_p.tag.test_value = np.zeros((m,))
36+
# Time. Allow for non-autonomous systems of ODEs to be analyzed
37+
t_t = tt.scalar("t", dtype=theano.config.floatX)
38+
t_t.tag.test_value = 2.459
39+
40+
# Present state of the gradients:
41+
# Will always be 0 unless the parameter is the initial condition
42+
# Entry i,j is the partial derivative of y[i] with respect to p[j]
43+
dydp_vec = tt.vector("dydp", dtype=theano.config.floatX)
44+
dydp_vec.tag.test_value = np.zeros(n * m)
45+
46+
dydp = dydp_vec.reshape((n, m))
47+
48+
# Stack the results of the ode_func
49+
f_tensor = tt.stack(ode_func(t_y, t_t, t_p))
50+
51+
# Now compute gradients
52+
J = tt.jacobian(f_tensor, t_y)
53+
54+
Jdfdy = tt.dot(J, dydp)
55+
56+
grad_f = tt.jacobian(f_tensor, t_p)
57+
58+
# This is the time derivative of dydp
59+
ddt_dydp = (Jdfdy + grad_f).flatten()
60+
61+
system = theano.function(
62+
inputs=[t_y, t_t, t_p, dydp_vec],
63+
outputs=[f_tensor, ddt_dydp],
64+
on_unused_input="ignore",
65+
)
66+
67+
return system
68+
69+
70+
class ODEGradop(theano.Op):
71+
def __init__(self, numpy_vsp):
72+
self._numpy_vsp = numpy_vsp
73+
74+
def make_node(self, x, g):
75+
76+
x = theano.tensor.as_tensor_variable(x)
77+
g = theano.tensor.as_tensor_variable(g)
78+
node = theano.Apply(self, [x, g], [g.type()])
79+
return node
80+
81+
def perform(self, node, inputs_storage, output_storage):
82+
x = inputs_storage[0]
83+
g = inputs_storage[1]
84+
out = output_storage[0]
85+
out[0] = self._numpy_vsp(x, g) # get the numerical VSP

pymc3/plots/__init__.py

+1-20
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,8 @@
77
import functools
88
import sys
99
import warnings
10-
try:
11-
import arviz as az
12-
except ImportError: # arviz is optional, throw exception when used
1310

14-
class _ImportWarner:
15-
__all__ = []
16-
17-
def __init__(self, attr):
18-
self.attr = attr
19-
20-
def __call__(self, *args, **kwargs):
21-
raise ImportError(
22-
"ArviZ is not installed. In order to use `{0.attr}`:\npip install arviz".format(self)
23-
)
24-
25-
class _ArviZ:
26-
def __getattr__(self, attr):
27-
return _ImportWarner(attr)
28-
29-
30-
az = _ArviZ()
11+
import arviz as az
3112

3213
def map_args(func):
3314
swaps = [

0 commit comments

Comments
 (0)