Skip to content

Commit 5df81f3

Browse files
gjgressGabriel Gress
and
Gabriel Gress
authored
Fixes to Bayes SDE notebook (#449)
* Updated retcode success check to the current version used by SciMLBase. Introduced new noisy observations that are better suited for the problem, and corrected the model to calculate the likelihood based on multiple trajectories rather than a single trajectory. * Added more explanation on the likelihood calculation --------- Co-authored-by: Gabriel Gress <[email protected]>
1 parent d901459 commit 5df81f3

File tree

1 file changed

+19
-8
lines changed
  • tutorials/10-bayesian-stochastic-differential-equations

1 file changed

+19
-8
lines changed

tutorials/10-bayesian-stochastic-differential-equations/index.qmd

+19-8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Pkg.instantiate();
1717
```{julia}
1818
using Turing
1919
using DifferentialEquations
20+
using DifferentialEquations.EnsembleAnalysis
2021
2122
# Load StatsPlots for visualizations and diagnostics.
2223
using StatsPlots
@@ -117,6 +118,10 @@ prob_sde = SDEProblem(lotka_volterra!, multiplicative_noise!, u0, tspan, p)
117118
ensembleprob = EnsembleProblem(prob_sde)
118119
data = solve(ensembleprob, SOSRI(); saveat=0.1, trajectories=1000)
119120
plot(EnsembleSummary(data))
121+
122+
# We generate new noisy observations based on the stochastic model for the parameter estimation tasks in this tutorial.
123+
# We create our observations by adding random normally distributed noise to the mean of the ensemble simulation.
124+
sdedata = reduce(hcat, timeseries_steps_mean(data).u) + 0.8 * randn(size(reduce(hcat, timeseries_steps_mean(data).u)))
120125
```
121126

122127
```{julia}
@@ -132,17 +137,24 @@ plot(EnsembleSummary(data))
132137
133138
# Simulate stochastic Lotka-Volterra model.
134139
p = [α, β, γ, δ, ϕ1, ϕ2]
135-
predicted = solve(prob, SOSRI(); p=p, saveat=0.1)
140+
remake(prob, p = p)
141+
ensembleprob = EnsembleProblem(prob)
142+
predicted = solve(ensembleprob, SOSRI(); saveat=0.1, trajectories = 1000)
136143
137144
# Early exit if simulation could not be computed successfully.
138-
if predicted.retcode !== :Success
139-
Turing.@addlogprob! -Inf
140-
return nothing
145+
for i in 1:length(predicted)
146+
if !SciMLBase.successful_retcode(predicted[i])
147+
Turing.@addlogprob! -Inf
148+
return nothing
149+
end
141150
end
142151
143152
# Observations.
144-
for i in 1:length(predicted)
145-
data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
153+
# We compute the likelihood for each trajectory of our simulation in order to better approximate the overall likelihood of our choice of parameters
154+
for j in 1:length(predicted)
155+
for i in 1:length(predicted[j])
156+
data[:, i] ~ MvNormal(predicted[j][i], σ^2 * I)
157+
end
146158
end
147159
148160
return nothing
@@ -154,9 +166,8 @@ Therefore we use NUTS with a low target acceptance rate of `0.25` and specify a
154166
SGHMC might be a more suitable algorithm to be used here.
155167

156168
```{julia}
157-
model_sde = fitlv_sde(odedata, prob_sde)
169+
model_sde = fitlv_sde(sdedata, prob_sde)
158170
159-
setadbackend(:forwarddiff)
160171
chain_sde = sample(
161172
model_sde,
162173
NUTS(0.25),

0 commit comments

Comments
 (0)