Skip to content

Commit 665e897

Browse files
committed
Explicitly output new infections, rather than calculating it post-hoc
1 parent 03abb88 commit 665e897

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed

src/penn_chime/models.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(self, p: Parameters) -> SimSirModel:
100100

101101
def sir(
102102
s: float, i: float, r: float, beta: float, gamma: float, n: float
103-
) -> Tuple[float, float, float]:
103+
) -> Tuple[float, float, float, float]:
104104
"""The SIR model, one time step."""
105105

106106
potential_new_infections = beta * s * i
@@ -119,18 +119,18 @@ def sir(
119119
r_n = 0.0
120120

121121
scale = n / (s_n + i_n + r_n)
122-
return s_n * scale, i_n * scale, r_n * scale
122+
return s_n * scale, i_n * scale, r_n * scale, new_infections * scale
123123

124124

125125
def gen_sir(
126126
s: float, i: float, r: float, beta: float, gamma: float, n_days: int
127-
) -> Generator[Tuple[float, float, float], None, None]:
127+
) -> Generator[Tuple[float, float, float, float], None, None]:
128128
"""Simulate SIR model forward in time yielding tuples."""
129-
s, i, r = (float(v) for v in (s, i, r))
129+
s, i, r, new_infections = (float(v) for v in (s, i, r, 0.0))
130130
n = s + i + r
131131
for d in range(n_days + 1):
132-
yield d, s, i, r
133-
s, i, r = sir(s, i, r, beta, gamma, n)
132+
yield d, s, i, r, new_infections
133+
s, i, r, new_infections = sir(s, i, r, beta, gamma, n)
134134

135135

136136
def sim_sir_df(
@@ -139,7 +139,7 @@ def sim_sir_df(
139139
"""Simulate the SIR model forward in time."""
140140
return pd.DataFrame(
141141
data=gen_sir(s, i, r, beta, gamma, n_days),
142-
columns=("day", "susceptible", "infected", "recovered"),
142+
columns=("day", "susceptible", "infected", "recovered", "new_infections"),
143143
)
144144

145145
def build_dispositions_df(

tests/test_app.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,12 @@ def test_sir():
137137
Someone who is good at testing, help
138138
"""
139139
sir_test = sir(100, 1, 0, 0.2, 0.5, 1)
140-
assert sir_test == (
141-
0.7920792079207921,
142-
0.20297029702970298,
143-
0.0049504950495049506,
144-
), "This contrived example should work"
145-
146-
assert isinstance(sir_test, tuple)
147-
for v in sir_test:
140+
s, i, r, i_n = sir_test
141+
assert s == 0.7920792079207921, "This contrived example should work"
142+
assert i == 0.20297029702970298, "This contrived example should work"
143+
assert r == 0.0049504950495049506, "This contrived example should work"
144+
145+
for v in (sir_test):
148146
assert isinstance(v, float)
149147
assert v >= 0
150148

0 commit comments

Comments
 (0)