Skip to content

Commit e5a1953

Browse files
723 update support_max after setting a new parameter (#724)
Co-authored-by: Martin J. Kühn <[email protected]>
1 parent 68cbd91 commit e5a1953

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

cpp/memilio/epidemiology/state_age_function.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,18 @@ struct StateAgeFunction {
132132
* @brief Set the m_parameter object.
133133
*
134134
* Can be used to set the m_parameter object, which specifies the used function.
135+
* The maximum support of a function may be costly to evaluate. In order to not always reevaluate or recompute the
136+
* support when the user asks for it, a cached value is used. If m_support_max is set to -1, the cached value is
137+
* deleted and a recomputation is done the next time the user asks for the support. As the support (potentially)
138+
* depends on the m_parameter object, the cached value has to be deleted. For details see get_support_max().
135139
*
136140
*@param[in] new_parameter New parameter for StateAgeFunction.
137141
*/
138142
void set_parameter(ScalarType new_parameter)
139143
{
140144
m_parameter = new_parameter;
145+
146+
m_support_max = -1.;
141147
}
142148

143149
/**
@@ -156,7 +162,7 @@ struct StateAgeFunction {
156162
*/
157163
virtual ScalarType get_support_max(ScalarType dt, ScalarType tol = 1e-10)
158164
{
159-
ScalarType support_max = 0;
165+
ScalarType support_max = 0.;
160166

161167
if (!floating_point_equal(m_support_tol, tol, 1e-14) || floating_point_equal(m_support_max, -1., 1e-14)) {
162168
while (eval(support_max) >= tol) {
@@ -288,7 +294,8 @@ struct SmootherCosine : public StateAgeFunction {
288294
{
289295
unused(dt);
290296
unused(tol);
291-
return m_parameter;
297+
m_support_max = m_parameter;
298+
return m_support_max;
292299
}
293300

294301
protected:
@@ -350,11 +357,12 @@ struct ConstantFunction : public StateAgeFunction {
350357

351358
unused(dt);
352359
unused(tol);
360+
m_support_max = -2.;
353361

354362
log_error("This function is not suited to be a TransitionDistribution. Do not call in case of StateAgeFunctions"
355363
"of type b); see documentation of StateAgeFunction Base class.");
356364

357-
return (ScalarType)(-2);
365+
return m_support_max;
358366
}
359367

360368
protected:

cpp/tests/test_ide_secir.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ TEST(IdeSecir, checkProportionRecoveredDeath)
432432
init.add_time_point(init.get_last_time() + dt, vec_init);
433433
}
434434

435-
// Initialize two models.
435+
// Initialize model.
436436
mio::isecir::Model model(std::move(init), N, Dead_before);
437437

438438
// Set working parameters.

cpp/tests/test_state_age_function.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,21 @@ TEST(TestStateAgeFunction, testGetSupportMax)
104104
ScalarType dt = 0.5;
105105

106106
// test get_support_max for all derived classes as this method can be overridden
107+
// Check that the maximum support is correct after setting the parameter object of a StateAgeFunction.
107108
mio::ExponentialDecay expdecay(1.0);
108109
EXPECT_NEAR(expdecay.get_support_max(dt), 23.5, 1e-14);
110+
expdecay.set_parameter(2.0);
111+
EXPECT_NEAR(expdecay.get_support_max(dt), 12.0, 1e-14);
109112

110113
mio::SmootherCosine smoothcos(1.0);
111114
EXPECT_NEAR(smoothcos.get_support_max(dt), 1.0, 1e-14);
115+
smoothcos.set_parameter(2.0);
116+
EXPECT_NEAR(smoothcos.get_support_max(dt), 2.0, 1e-14);
112117

113118
mio::ConstantFunction constfunc(1.0);
114119
EXPECT_NEAR(constfunc.get_support_max(dt), -2.0, 1e-14);
120+
constfunc.set_parameter(2.0);
121+
EXPECT_NEAR(constfunc.get_support_max(dt), -2.0, 1e-14);
115122
}
116123

117124
TEST(TestStateAgeFunction, testSAFWrapperSpecialMember)
@@ -150,7 +157,7 @@ TEST(TestStateAgeFunction, testSAFWrapperSpecialMember)
150157
// test true copy, not reference
151158
wrapper.set_parameter(2.0);
152159
EXPECT_NE(wrapper.get_parameter(), wrapper4.get_parameter());
153-
wrapper.set_parameter(1.0);
160+
wrapper.set_parameter(1.0);
154161

155162
// move assignment
156163
mio::StateAgeFunctionWrapper wrapper5 = std::move(wrapper4);

0 commit comments

Comments
 (0)