Skip to content

Commit 2b609bf

Browse files
HenrZuMaxBetzDLR
andauthored
1438 normalize transitions in ode mseirs4 model (#1439)
- Normalize force of infection by total population. - Log in the example now way shorter. Co-authored-by: MaxBetz <[email protected]>
1 parent edb7428 commit 2b609bf

File tree

2 files changed

+69
-15
lines changed

2 files changed

+69
-15
lines changed

cpp/models/ode_mseirs4/model.h

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,28 +77,35 @@ class Model : public mio::CompartmentalModel<FP, InfectionState, mio::Population
7777
return static_cast<size_t>(s);
7878
};
7979

80-
FP I_total = y[idx(InfectionState::I1)] + y[idx(InfectionState::I2)] + y[idx(InfectionState::I3)] +
81-
y[idx(InfectionState::I4)];
82-
FP R_total = y[idx(InfectionState::R1)] + y[idx(InfectionState::R2)] + y[idx(InfectionState::R3)] +
83-
y[idx(InfectionState::R4)];
84-
FP N = pop.sum();
80+
const FP I_total = y[idx(InfectionState::I1)] + y[idx(InfectionState::I2)] + y[idx(InfectionState::I3)] +
81+
y[idx(InfectionState::I4)];
82+
const FP R_total = y[idx(InfectionState::R1)] + y[idx(InfectionState::R2)] + y[idx(InfectionState::R3)] +
83+
y[idx(InfectionState::R4)];
84+
const FP N = pop.sum();
85+
const FP inv_N = (N > Limits<FP>::zero_tolerance())
86+
? FP(1) / N
87+
: FP(0.0); // avoid excessive force of infection or division by zero when empty
88+
const FP lambda1 = beta1 * I_total * inv_N;
89+
const FP lambda2 = beta2 * I_total * inv_N;
90+
const FP lambda3 = beta3 * I_total * inv_N;
91+
const FP lambda4 = beta4 * I_total * inv_N;
8592

8693
// dM
8794
dydt[idx(InfectionState::MaternalImmune)] = mu * R_total - (xi + mu) * y[idx(InfectionState::MaternalImmune)];
8895

8996
// dS1
9097
dydt[idx(InfectionState::S1)] = mu * (N - R_total) + xi * y[idx(InfectionState::MaternalImmune)] -
91-
mu * y[idx(InfectionState::S1)] - beta1 * I_total * y[idx(InfectionState::S1)];
98+
mu * y[idx(InfectionState::S1)] - lambda1 * y[idx(InfectionState::S1)];
9299

93100
// dE1..E4
94101
dydt[idx(InfectionState::E1)] =
95-
beta1 * I_total * y[idx(InfectionState::S1)] - (mu + sigma) * y[idx(InfectionState::E1)];
102+
lambda1 * y[idx(InfectionState::S1)] - (mu + sigma) * y[idx(InfectionState::E1)];
96103
dydt[idx(InfectionState::E2)] =
97-
beta2 * I_total * y[idx(InfectionState::S2)] - (mu + sigma) * y[idx(InfectionState::E2)];
104+
lambda2 * y[idx(InfectionState::S2)] - (mu + sigma) * y[idx(InfectionState::E2)];
98105
dydt[idx(InfectionState::E3)] =
99-
beta3 * I_total * y[idx(InfectionState::S3)] - (mu + sigma) * y[idx(InfectionState::E3)];
106+
lambda3 * y[idx(InfectionState::S3)] - (mu + sigma) * y[idx(InfectionState::E3)];
100107
dydt[idx(InfectionState::E4)] =
101-
beta4 * I_total * y[idx(InfectionState::S4)] - (mu + sigma) * y[idx(InfectionState::E4)];
108+
lambda4 * y[idx(InfectionState::S4)] - (mu + sigma) * y[idx(InfectionState::E4)];
102109

103110
// dI1..I4
104111
dydt[idx(InfectionState::I1)] = sigma * y[idx(InfectionState::E1)] - (nu + mu) * y[idx(InfectionState::I1)];
@@ -113,12 +120,12 @@ class Model : public mio::CompartmentalModel<FP, InfectionState, mio::Population
113120
dydt[idx(InfectionState::R4)] = nu * y[idx(InfectionState::I4)] - (mu + gamma) * y[idx(InfectionState::R4)];
114121

115122
// dS2,S3,S4
116-
dydt[idx(InfectionState::S2)] = gamma * y[idx(InfectionState::R1)] - mu * y[idx(InfectionState::S2)] -
117-
beta2 * I_total * y[idx(InfectionState::S2)];
118-
dydt[idx(InfectionState::S3)] = gamma * y[idx(InfectionState::R2)] - mu * y[idx(InfectionState::S3)] -
119-
beta3 * I_total * y[idx(InfectionState::S3)];
123+
dydt[idx(InfectionState::S2)] =
124+
gamma * y[idx(InfectionState::R1)] - mu * y[idx(InfectionState::S2)] - lambda2 * y[idx(InfectionState::S2)];
125+
dydt[idx(InfectionState::S3)] =
126+
gamma * y[idx(InfectionState::R2)] - mu * y[idx(InfectionState::S3)] - lambda3 * y[idx(InfectionState::S3)];
120127
dydt[idx(InfectionState::S4)] = gamma * (y[idx(InfectionState::R3)] + y[idx(InfectionState::R4)]) -
121-
mu * y[idx(InfectionState::S4)] - beta4 * I_total * y[idx(InfectionState::S4)];
128+
mu * y[idx(InfectionState::S4)] - lambda4 * y[idx(InfectionState::S4)];
122129
}
123130

124131
/**

cpp/tests/test_odemseirs4.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,50 @@ TEST(TestOdeMseirs4, Simulation)
205205

206206
EXPECT_EQ(sim.get_num_time_points(), 2);
207207
}
208+
209+
TEST(TestOdeMseirs4, normalized_transitions)
210+
{
211+
// case: get derivative with isolated infection terms; expect derivative match computed values
212+
mio::omseirs4::Model<double> model;
213+
auto& params = model.parameters;
214+
params.get<mio::omseirs4::BaseTransmissionRate<double>>() = 0.4;
215+
216+
// disable other flows to isolate infection terms
217+
params.get<mio::omseirs4::NaturalBirthDeathRate<double>>() = 0.0;
218+
params.get<mio::omseirs4::LossMaternalImmunityRate<double>>() = 0.0;
219+
params.get<mio::omseirs4::ProgressionRate<double>>() = 0.0;
220+
params.get<mio::omseirs4::RecoveryRate<double>>() = 0.0;
221+
params.get<mio::omseirs4::ImmunityWaningRate<double>>() = 0.0;
222+
params.get<mio::omseirs4::SeasonalAmplitude<double>>() = 0.0;
223+
224+
using IS = mio::omseirs4::InfectionState;
225+
model.populations[{mio::Index<IS>(IS::S1)}] = 500.0;
226+
model.populations[{mio::Index<IS>(IS::S2)}] = 300.0;
227+
model.populations[{mio::Index<IS>(IS::S3)}] = 200.0;
228+
model.populations[{mio::Index<IS>(IS::S4)}] = 100.0;
229+
model.populations[{mio::Index<IS>(IS::I1)}] = 10.0;
230+
model.populations[{mio::Index<IS>(IS::I2)}] = 5.0;
231+
model.populations[{mio::Index<IS>(IS::I3)}] = 2.0;
232+
model.populations[{mio::Index<IS>(IS::I4)}] = 3.0;
233+
234+
auto y0 = model.get_initial_values();
235+
auto dydt = Eigen::VectorXd((Eigen::Index)IS::Count);
236+
model.get_derivatives(y0, y0, 0.0, dydt);
237+
238+
const double N = 500.0 + 300.0 + 200.0 + 100.0 + 10.0 + 5.0 + 2.0 + 3.0;
239+
const double I_total = 10.0 + 5.0 + 2.0 + 3.0;
240+
const double lambda1 = 0.4 * (I_total / N);
241+
const double lambda2 = 0.5 * lambda1;
242+
const double lambda3 = 0.35 * lambda1;
243+
const double lambda4 = 0.25 * lambda1;
244+
const double tol = 1e-12;
245+
246+
EXPECT_NEAR(dydt[(Eigen::Index)IS::S1], -lambda1 * 500.0, tol);
247+
EXPECT_NEAR(dydt[(Eigen::Index)IS::E1], lambda1 * 500.0, tol);
248+
EXPECT_NEAR(dydt[(Eigen::Index)IS::S2], -lambda2 * 300.0, tol);
249+
EXPECT_NEAR(dydt[(Eigen::Index)IS::E2], lambda2 * 300.0, tol);
250+
EXPECT_NEAR(dydt[(Eigen::Index)IS::S3], -lambda3 * 200.0, tol);
251+
EXPECT_NEAR(dydt[(Eigen::Index)IS::E3], lambda3 * 200.0, tol);
252+
EXPECT_NEAR(dydt[(Eigen::Index)IS::S4], -lambda4 * 100.0, tol);
253+
EXPECT_NEAR(dydt[(Eigen::Index)IS::E4], lambda4 * 100.0, tol);
254+
}

0 commit comments

Comments
 (0)