Skip to content

Commit 124313a

Browse files
johapauJohannssenHenrZumknaranja
authored
676 computation of the reproduction number in the seir model (#685)
Co-authored-by: Johannssen <[email protected]> Co-authored-by: HenrZu <[email protected]> Co-authored-by: Henrik Zunker <[email protected]> Co-authored-by: Martin Kühn <[email protected]>
1 parent e3076c2 commit 124313a

File tree

2 files changed

+230
-2
lines changed

2 files changed

+230
-2
lines changed

cpp/models/ode_seir/model.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,16 @@
2121
#define SEIR_MODEL_H
2222

2323
#include "memilio/compartments/compartmentalmodel.h"
24+
#include "memilio/config.h"
2425
#include "memilio/epidemiology/populations.h"
2526
#include "memilio/epidemiology/contact_matrix.h"
27+
#include "memilio/io/io.h"
28+
#include "memilio/math/interpolation.h"
29+
#include "memilio/utils/time_series.h"
2630
#include "ode_seir/infection_state.h"
2731
#include "ode_seir/parameters.h"
32+
#include <algorithm>
33+
#include <iterator>
2834

2935
namespace mio
3036
{
@@ -63,6 +69,75 @@ class Model : public CompartmentalModel<InfectionState, Populations<InfectionSta
6369
dydt[(size_t)InfectionState::Recovered] =
6470
(1.0 / params.get<TimeInfected>()) * y[(size_t)InfectionState::Infected];
6571
}
72+
73+
/**
74+
*@brief Computes the reproduction number at a given index time of the Model output obtained by the Simulation.
75+
*@param t_idx The index time at which the reproduction number is computed.
76+
*@param y The TimeSeries obtained from the Model Simulation.
77+
*@returns The computed reproduction number at the provided index time.
78+
*/
79+
IOResult<ScalarType> get_reproduction_number(size_t t_idx, const mio::TimeSeries<ScalarType>& y)
80+
{
81+
if (!(t_idx < static_cast<size_t>(y.get_num_time_points()))) {
82+
return mio::failure(mio::StatusCode::OutOfRange, "t_idx is not a valid index for the TimeSeries");
83+
}
84+
85+
ScalarType TimeInfected = this->parameters.get<mio::oseir::TimeInfected>();
86+
87+
ScalarType coeffStoE = this->parameters.get<mio::oseir::ContactPatterns>().get_matrix_at(
88+
y.get_time(static_cast<Eigen::Index>(t_idx)))(0, 0) *
89+
this->parameters.get<mio::oseir::TransmissionProbabilityOnContact>() /
90+
this->populations.get_total();
91+
92+
ScalarType result =
93+
y.get_value(static_cast<Eigen::Index>(t_idx))[(Eigen::Index)mio::oseir::InfectionState::Susceptible] *
94+
TimeInfected * coeffStoE;
95+
96+
return mio::success(result);
97+
}
98+
99+
/**
100+
*@brief Computes the reproduction number for all time points of the Model output obtained by the Simulation.
101+
*@param y The TimeSeries obtained from the Model Simulation.
102+
*@returns vector containing all reproduction numbers
103+
*/
104+
Eigen::VectorXd get_reproduction_numbers(const mio::TimeSeries<ScalarType>& y)
105+
{
106+
auto num_time_points = y.get_num_time_points();
107+
Eigen::VectorXd temp(num_time_points);
108+
for (size_t i = 0; i < static_cast<size_t>(num_time_points); i++) {
109+
temp[i] = get_reproduction_number(i, y).value();
110+
}
111+
return temp;
112+
}
113+
114+
/**
115+
*@brief Computes the reproduction number at a given time point of the Model output obtained by the Simulation. If the particular time point is not inside the output, a linearly interpolated value is returned.
116+
*@param t_value The time point at which the reproduction number is computed.
117+
*@param y The TimeSeries obtained from the Model Simulation.
118+
*@returns The computed reproduction number at the provided time point, potentially using linear interpolation.
119+
*/
120+
IOResult<ScalarType> get_reproduction_number(ScalarType t_value, const mio::TimeSeries<ScalarType>& y)
121+
{
122+
if (t_value < y.get_time(0) || t_value > y.get_last_time()) {
123+
return mio::failure(mio::StatusCode::OutOfRange,
124+
"Cannot interpolate reproduction number outside computed horizon of the TimeSeries");
125+
}
126+
127+
if (t_value == y.get_time(0)) {
128+
return mio::success(get_reproduction_number((size_t)0, y).value());
129+
}
130+
131+
auto times = std::vector<ScalarType>(y.get_times().begin(), y.get_times().end());
132+
133+
auto time_late = std::distance(times.begin(), std::lower_bound(times.begin(), times.end(), t_value));
134+
135+
ScalarType y1 = get_reproduction_number(static_cast<size_t>(time_late - 1), y).value();
136+
ScalarType y2 = get_reproduction_number(static_cast<size_t>(time_late), y).value();
137+
138+
auto result = linear_interpolation(t_value, y.get_time(time_late - 1), y.get_time(time_late), y1, y2);
139+
return mio::success(static_cast<ScalarType>(result));
140+
}
66141
};
67142

68143
} // namespace oseir

cpp/tests/test_odeseir.cpp

Lines changed: 155 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,16 @@
1818
* limitations under the License.
1919
*/
2020
#include "load_test_data.h"
21+
#include "memilio/config.h"
22+
#include "memilio/utils/time_series.h"
2123
#include "ode_seir/model.h"
2224
#include "ode_seir/infection_state.h"
2325
#include "ode_seir/parameters.h"
2426
#include "memilio/math/euler.h"
2527
#include "memilio/compartments/simulation.h"
2628
#include <gtest/gtest.h>
29+
#include <iomanip>
30+
#include <vector>
2731

2832
TEST(TestSeir, simulateDefault)
2933
{
@@ -150,7 +154,6 @@ TEST(TestSeir, check_constraints_parameters)
150154
model.parameters.set<mio::oseir::TimeInfected>(6);
151155
model.parameters.set<mio::oseir::TransmissionProbabilityOnContact>(10.);
152156
ASSERT_EQ(model.parameters.check_constraints(), 1);
153-
154157
mio::set_log_level(mio::LogLevel::warn);
155158
}
156159

@@ -176,6 +179,156 @@ TEST(TestSeir, apply_constraints_parameters)
176179
model.parameters.set<mio::oseir::TransmissionProbabilityOnContact>(10.);
177180
EXPECT_EQ(model.parameters.apply_constraints(), 1);
178181
EXPECT_NEAR(model.parameters.get<mio::oseir::TransmissionProbabilityOnContact>(), 0.0, 1e-14);
179-
180182
mio::set_log_level(mio::LogLevel::warn);
181183
}
184+
185+
TEST(TestSeir, get_reproduction_numbers)
186+
{
187+
mio::oseir::Model model;
188+
189+
double total_population = 10000;
190+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Exposed)}] = 100;
191+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Infected)}] = 100;
192+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Recovered)}] = 100;
193+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Susceptible)}] =
194+
total_population -
195+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Exposed)}] -
196+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Infected)}] -
197+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Recovered)}];
198+
199+
model.parameters.set<mio::oseir::TimeInfected>(6);
200+
model.parameters.set<mio::oseir::TransmissionProbabilityOnContact>(0.04);
201+
model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 10;
202+
203+
model.apply_constraints();
204+
205+
Eigen::VectorXd checkReproductionNumbers(7);
206+
checkReproductionNumbers << 2.3280000000000002913, 2.3279906878991880603, 2.3279487809434575851,
207+
2.3277601483151548756, 2.3269102025388899158, 2.3230580052413736247, 2.3185400624683065729;
208+
209+
Eigen::VectorXd checkReproductionNumbers2(7);
210+
checkReproductionNumbers2 << 2.0952000000000001734, 2.0951916191092689878, 2.0951539028491117378,
211+
2.0949841334836394324, 2.0942191822850007021, 2.0907522047172362178, 2.086686056221475738;
212+
213+
Eigen::VectorXd checkReproductionNumbers3(7);
214+
checkReproductionNumbers3 << 1.8623999999999998334, 1.8623925503193501374, 1.8623590247547658905,
215+
1.8622081186521235452, 1.8615281620311117106, 1.8584464041930985889, 1.854832049974644903;
216+
217+
mio::TimeSeries<ScalarType> result((int)mio::oseir::InfectionState::Count);
218+
mio::TimeSeries<ScalarType>::Vector result_0(4);
219+
mio::TimeSeries<ScalarType>::Vector result_1(4);
220+
mio::TimeSeries<ScalarType>::Vector result_2(4);
221+
mio::TimeSeries<ScalarType>::Vector result_3(4);
222+
mio::TimeSeries<ScalarType>::Vector result_4(4);
223+
mio::TimeSeries<ScalarType>::Vector result_5(4);
224+
mio::TimeSeries<ScalarType>::Vector result_6(4);
225+
226+
result_0[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9700;
227+
result_1[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.9611995799496071;
228+
result_2[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.7865872644051706;
229+
result_3[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.0006179798110679;
230+
result_4[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9695.4591772453732119;
231+
result_5[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9679.4083551723888377;
232+
result_6[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9660.5835936179428245;
233+
234+
result.add_time_point(0, result_0);
235+
result.add_time_point(0.0010000000000000000208, result_1);
236+
result.add_time_point(0.0055000000000000005482, result_2);
237+
result.add_time_point(0.025750000000000005523, result_3);
238+
result.add_time_point(0.11687500000000002054, result_4);
239+
result.add_time_point(0.52693750000000005862, result_5);
240+
result.add_time_point(1, result_6);
241+
242+
auto reproduction_numbers = model.get_reproduction_numbers(result);
243+
244+
for (int i = 0; i < reproduction_numbers.size(); i++) {
245+
EXPECT_NEAR(reproduction_numbers[i], checkReproductionNumbers[i], 1e-12);
246+
}
247+
248+
model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 9;
249+
250+
auto reproduction_numbers2 = model.get_reproduction_numbers(result);
251+
252+
for (int i = 0; i < reproduction_numbers2.size(); i++) {
253+
EXPECT_NEAR(reproduction_numbers2[i], checkReproductionNumbers2[i], 1e-12);
254+
}
255+
256+
model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 8;
257+
258+
auto reproduction_numbers3 = model.get_reproduction_numbers(result);
259+
260+
for (int i = 0; i < reproduction_numbers2.size(); i++) {
261+
EXPECT_NEAR(reproduction_numbers3[i], checkReproductionNumbers3[i], 1e-12);
262+
}
263+
264+
EXPECT_FALSE(model.get_reproduction_number(static_cast<double>(static_cast<size_t>(result.get_num_time_points())),
265+
result)); //Test for an index that is out of range
266+
}
267+
268+
TEST(TestSeir, get_reproduction_number)
269+
{
270+
mio::oseir::Model model;
271+
272+
double total_population = 10000; //Initialize compartments to get total population of 10000
273+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Exposed)}] = 100;
274+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Infected)}] = 100;
275+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Recovered)}] = 100;
276+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Susceptible)}] =
277+
total_population -
278+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Exposed)}] -
279+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Infected)}] -
280+
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Recovered)}];
281+
282+
model.parameters.set<mio::oseir::TimeInfected>(6);
283+
model.parameters.set<mio::oseir::TransmissionProbabilityOnContact>(0.04);
284+
model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 10;
285+
286+
model.apply_constraints();
287+
288+
mio::TimeSeries<ScalarType> result((int)mio::oseir::InfectionState::Count);
289+
mio::TimeSeries<ScalarType>::Vector result_0(4);
290+
mio::TimeSeries<ScalarType>::Vector result_1(4);
291+
mio::TimeSeries<ScalarType>::Vector result_2(4);
292+
mio::TimeSeries<ScalarType>::Vector result_3(4);
293+
mio::TimeSeries<ScalarType>::Vector result_4(4);
294+
mio::TimeSeries<ScalarType>::Vector result_5(4);
295+
mio::TimeSeries<ScalarType>::Vector result_6(4);
296+
mio::TimeSeries<ScalarType>::Vector result_7(4);
297+
298+
result_0[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9700;
299+
result_1[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.9709149074315;
300+
result_2[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.8404009584538;
301+
result_3[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.260556488618;
302+
result_4[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9696.800490904101;
303+
result_5[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9687.9435082620021;
304+
result_6[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9679.5436372291661;
305+
result_7[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9678.5949381732935;
306+
307+
result.add_time_point(0, result_0);
308+
result.add_time_point(0.001, result_1);
309+
result.add_time_point(0.0055, result_2);
310+
result.add_time_point(0.02575, result_3);
311+
result.add_time_point(0.116875, result_4);
312+
result.add_time_point(0.526938, result_5);
313+
result.add_time_point(0.952226, result_6);
314+
result.add_time_point(1, result_7);
315+
316+
EXPECT_FALSE(model.get_reproduction_number(result.get_time(0) - 0.5, result)); //Test for indices out of range
317+
EXPECT_FALSE(model.get_reproduction_number(result.get_last_time() + 0.5, result));
318+
EXPECT_FALSE(model.get_reproduction_number((size_t)result.get_num_time_points(), result));
319+
320+
EXPECT_EQ(model.get_reproduction_number((size_t)0, result).value(),
321+
model.get_reproduction_number(0.0, result).value());
322+
323+
EXPECT_NEAR(model.get_reproduction_number(0.3, result).value(), 2.3262828383474389859, 1e-12);
324+
EXPECT_NEAR(model.get_reproduction_number(0.7, result).value(), 2.3242860858116172196, 1e-12);
325+
EXPECT_NEAR(model.get_reproduction_number(0.0, result).value(), 2.3280000000000002913, 1e-12);
326+
327+
model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 9;
328+
EXPECT_NEAR(model.get_reproduction_number(0.1, result).value(), 2.0946073086586665113, 1e-12);
329+
EXPECT_NEAR(model.get_reproduction_number(0.3, result).value(), 2.0936545545126947765, 1e-12);
330+
331+
model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 8;
332+
EXPECT_NEAR(model.get_reproduction_number(0.2, result).value(), 1.8614409729718137676, 1e-12);
333+
EXPECT_NEAR(model.get_reproduction_number(0.9, result).value(), 1.858670429549998504, 1e-12);
334+
}

0 commit comments

Comments
 (0)