@@ -173,6 +173,7 @@ def pairwise(container: list):
173
173
next (b , None )
174
174
return zip (a , b )
175
175
176
+ LOGGER .debug (f"{ self .__class__ .__name__ } : Calc risk periods" )
176
177
# impfset = self._merge_impfset(snapshots)
177
178
return [
178
179
CalcRiskPeriod (
@@ -227,26 +228,25 @@ def _generic_metrics(
227
228
# Construct the attribute name for storing the metric results
228
229
attr_name = f"_{ metric_name } _metrics"
229
230
230
- if getattr (self , attr_name , None ) is None :
231
- tmp = []
232
- for calc_period in self .risk_periods :
233
- # Call the specified method on the calc_period object
234
- tmp .append (getattr (calc_period , metric_meth )(** kwargs ))
235
-
236
- tmp = pd .concat (tmp )
237
- tmp .drop_duplicates (inplace = True )
238
- tmp ["group" ] = tmp ["group" ].fillna (self ._all_groups_name )
239
- columns_to_front = ["group" , "date" , "measure" , "metric" ]
240
- tmp = tmp [
241
- columns_to_front
242
- + [
243
- col
244
- for col in tmp .columns
245
- if col not in columns_to_front + ["group" , "risk" , "rp" ]
246
- ]
247
- + ["risk" ]
231
+ tmp = []
232
+ for calc_period in self .risk_periods :
233
+ # Call the specified method on the calc_period object
234
+ tmp .append (getattr (calc_period , metric_meth )(** kwargs ))
235
+
236
+ tmp = pd .concat (tmp )
237
+ tmp .drop_duplicates (inplace = True )
238
+ tmp ["group" ] = tmp ["group" ].fillna (self ._all_groups_name )
239
+ columns_to_front = ["group" , "date" , "measure" , "metric" ]
240
+ tmp = tmp [
241
+ columns_to_front
242
+ + [
243
+ col
244
+ for col in tmp .columns
245
+ if col not in columns_to_front + ["group" , "risk" , "rp" ]
248
246
]
249
- setattr (self , attr_name , tmp )
247
+ + ["risk" ]
248
+ ]
249
+ setattr (self , attr_name , tmp )
250
250
251
251
if npv :
252
252
return self .npv_transform (getattr (self , attr_name ), self .risk_disc )
@@ -271,40 +271,39 @@ def _compute_metrics(
271
271
)
272
272
return df
273
273
274
- def eai_metrics (self , npv : bool = True ):
274
+ def eai_metrics (self , npv : bool = True , ** kwargs ):
275
275
return self ._compute_metrics (
276
- npv = npv ,
277
- metric_name = "eai" ,
278
- metric_meth = "calc_eai_gdf" ,
276
+ npv = npv , metric_name = "eai" , metric_meth = "calc_eai_gdf" , ** kwargs
279
277
)
280
278
281
- def aai_metrics (self , npv : bool = True ):
279
+ def aai_metrics (self , npv : bool = True , ** kwargs ):
282
280
return self ._compute_metrics (
283
- npv = npv ,
284
- metric_name = "aai" ,
285
- metric_meth = "calc_aai_metric" ,
281
+ npv = npv , metric_name = "aai" , metric_meth = "calc_aai_metric" , ** kwargs
286
282
)
287
283
288
- def return_periods_metrics (self , return_periods , npv : bool = True ):
284
+ def return_periods_metrics (self , return_periods , npv : bool = True , ** kwargs ):
289
285
return self ._compute_metrics (
290
286
npv = npv ,
291
287
metric_name = "return_periods" ,
292
288
metric_meth = "calc_return_periods_metric" ,
293
289
return_periods = return_periods ,
290
+ ** kwargs ,
294
291
)
295
292
296
- def aai_per_group_metrics (self , npv : bool = True ):
293
+ def aai_per_group_metrics (self , npv : bool = True , ** kwargs ):
297
294
return self ._compute_metrics (
298
295
npv = npv ,
299
296
metric_name = "aai_per_group" ,
300
297
metric_meth = "calc_aai_per_group_metric" ,
298
+ ** kwargs ,
301
299
)
302
300
303
- def risk_components_metrics (self , npv : bool = True ):
301
+ def risk_components_metrics (self , npv : bool = True , ** kwargs ):
304
302
return self ._compute_metrics (
305
303
npv = npv ,
306
304
metric_name = "risk_components" ,
307
305
metric_meth = "calc_risk_components_metric" ,
306
+ ** kwargs ,
308
307
)
309
308
310
309
def per_date_risk_metrics (
0 commit comments