14
14
import numpy as np
15
15
import pandas as pd
16
16
17
+ from ..constants import PREFIT_ADDITIONAL_DAYS
17
18
from .parameters import Parameters
18
19
19
20
@@ -68,31 +69,42 @@ def __init__(self, p: Parameters):
68
69
69
70
if p .mitigation_date is None :
70
71
self .i_day = 0 # seed to the full length
71
- raw = self .run_projection (p , [(self .beta , p .n_days )])
72
+ raw = self .run_projection (p , [
73
+ (self .beta , p .n_days + PREFIT_ADDITIONAL_DAYS )])
72
74
self .i_day = i_day = int (get_argmin_ds (raw ["census_hospitalized" ], p .current_hospitalized ))
73
75
74
- self .raw = self .run_projection (p , self .gen_policy (p ))
76
+ self .raw = self .run_projection (p , self .get_policies (p ))
75
77
76
78
logger .info ('Set i_day = %s' , i_day )
77
79
else :
78
- projections = {}
79
80
best_i_day = - 1
80
81
best_i_day_loss = float ('inf' )
81
- for i_day in range (p .n_days ):
82
- self .i_day = i_day
83
- raw = self .run_projection (p , self .gen_policy (p ))
82
+ for self .i_day in range (p .n_days + PREFIT_ADDITIONAL_DAYS ):
83
+ mitigation_day = - (p .current_date - p .mitigation_date ).days
84
+ if mitigation_day < - self .i_day :
85
+ mitigation_day = - self .i_day
86
+
87
+ total_days = self .i_day + p .n_days + PREFIT_ADDITIONAL_DAYS
88
+ pre_mitigation_days = self .i_day + mitigation_day
89
+ post_mitigation_days = total_days - pre_mitigation_days
90
+
91
+ raw = self .run_projection (p , [
92
+ (self .beta , pre_mitigation_days ),
93
+ (self .beta_t , post_mitigation_days ),
94
+ ]
95
+ )
84
96
85
97
# Don't fit against results that put the peak before the present day
86
- if raw ["census_hospitalized" ].argmax () < i_day :
98
+ if raw ["census_hospitalized" ].argmax () < self . i_day :
87
99
continue
88
100
89
- loss = get_loss (raw ["census_hospitalized" ][i_day ], p .current_hospitalized )
101
+ loss = get_loss (raw ["census_hospitalized" ][self . i_day ], p .current_hospitalized )
90
102
if loss < best_i_day_loss :
91
103
best_i_day_loss = loss
92
- best_i_day = i_day
93
- self .raw = raw
104
+ best_i_day = self .i_day
94
105
95
106
self .i_day = best_i_day
107
+ self .raw = self .run_projection (p , self .get_policies (p ))
96
108
97
109
logger .info (
98
110
'Estimated date_first_hospitalized: %s; current_date: %s; i_day: %s' ,
@@ -127,7 +139,7 @@ def __init__(self, p: Parameters):
127
139
intrinsic_growth_rate = get_growth_rate (p .doubling_time )
128
140
self .beta = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , 0.0 )
129
141
self .beta_t = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , p .relative_contact_rate )
130
- self .raw = self .run_projection (p , self .gen_policy (p ))
142
+ self .raw = self .run_projection (p , self .get_policies (p ))
131
143
132
144
self .population = p .population
133
145
else :
@@ -162,6 +174,15 @@ def __init__(self, p: Parameters):
162
174
'census_icu' : self .raw ['census_icu' ],
163
175
'census_ventilated' : self .raw ['census_ventilated' ],
164
176
})
177
+ self .ppe_df = pd .DataFrame (data = {
178
+ 'day' : self .raw ['day' ],
179
+ 'date' : self .raw ['date' ],
180
+ 'census_hospitalized' : self .raw ['census_hospitalized' ],
181
+ 'census_icu' : self .raw ['census_icu' ],
182
+ 'census_ventilated' : self .raw ['census_ventilated' ],
183
+ 'admits_hospitalized' : self .raw ['admits_hospitalized' ],
184
+ })
185
+ self .ppe_df = self .ppe_df [self .ppe_df ['day' ]>= 0 ]
165
186
166
187
logger .info ('len(np.arange(-i_day, n_days+1)): %s' , len (np .arange (- self .i_day , p .n_days + 1 )))
167
188
logger .info ('len(raw_df): %s' , len (self .raw_df ))
@@ -195,9 +216,9 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
195
216
intrinsic_growth_rate = get_growth_rate (i_dt )
196
217
self .beta = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , 0.0 )
197
218
self .beta_t = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , p .relative_contact_rate )
198
-
199
- raw = self .run_projection (p , self .gen_policy (p ))
200
-
219
+
220
+ raw = self .run_projection (p , self .get_policies (p ))
221
+
201
222
# Skip values the would put the fit past peak
202
223
peak_admits_day = raw ["admits_hospitalized" ].argmax ()
203
224
if peak_admits_day < 0 :
@@ -210,7 +231,7 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
210
231
min_loss = pd .Series (losses ).argmin ()
211
232
return min_loss
212
233
213
- def gen_policy (self , p : Parameters ) -> Sequence [Tuple [float , int ]]:
234
+ def get_policies (self , p : Parameters ) -> Sequence [Tuple [float , int ]]:
214
235
if p .mitigation_date is not None :
215
236
mitigation_day = - (p .current_date - p .mitigation_date ).days
216
237
else :
0 commit comments