Skip to content

Commit cb2fddf

Browse files
committed
push scoring to the remote
1 parent 2f61f90 commit cb2fddf

File tree

6 files changed

+46
-87
lines changed

6 files changed

+46
-87
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,7 @@ reports/report.md
1818
cache/
1919
data/
2020
.vscode/
21+
.envrc
22+
.nhsn_covid_cache.parquet
23+
.nhsn_flu_cache.parquet
24+
meta/

R/targets/score_targets.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ get_external_forecasts <- function(disease) {
22
locations_crosswalk <- get_population_data() %>%
33
select(state_id, state_code) %>%
44
filter(state_id != "usa")
5-
arrow::read_parquet(paste::paste("data/forecasts/{disease}_hosp_forecasts.parquet")) %>%
5+
arrow::read_parquet(glue::glue("data/forecasts/{disease}_hosp_forecasts.parquet")) %>%
66
filter(output_type == "quantile") %>%
77
select(forecaster, geo_value = location, forecast_date, target_end_date, quantile = output_type_id, value) %>%
88
inner_join(locations_crosswalk, by = c("geo_value" = "state_code")) %>%
@@ -49,7 +49,7 @@ score_forecasts <- function(nhsn_latest_data, joined_forecasts_and_ensembles) {
4949
}
5050

5151

52-
render_score_plot <- function(disease) {
52+
render_score_plot <- function(score_report_rmd, scores, forecast_dates, disease) {
5353
rmarkdown::render(
5454
score_report_rmd,
5555
params = list(

R/utils.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,30 @@ update_site <- function(sync_to_s3 = TRUE) {
364364
prod_reports_index <- which(grepl("## Production Reports", report_md_content)) + 1
365365
report_md_content <- append(report_md_content, report_link, after = prod_reports_index)
366366
}
367+
# add scoring notebooks if they exist
368+
score_files <- dir_ls(reports_dir, regexp = ".*_backtesting_2024_2025_on_.*.html")
369+
if (length(score_files) > 0) {
370+
# a tibble of all score files, along with their generation date and disease
371+
score_table <- tibble(
372+
filename = score_files,
373+
dates = str_match_all(filename, "[0-9]{4}-..-..")
374+
) %>%
375+
unnest_wider(dates, names_sep = "_") %>%
376+
rename(generation_date = dates_1) %>%
377+
mutate(
378+
generation_date = ymd(generation_date),
379+
disease = str_match(filename, "flu|covid")
380+
)
381+
used_files <- score_table %>%
382+
group_by(disease) %>%
383+
slice_max(generation_date)
384+
# iterating over the diseases
385+
for (row_num in seq_along(used_files$filename)) {
386+
scoring_index <- which(grepl("### Scoring this season", report_md_content)) + 1
387+
score_link <- sprintf("- [%s Scoring, Rendered %s](%s)", str_to_title(used_files$disease[[row_num]]), used_files$generation_date[[row_num]], used_files$filename[[row_num]])
388+
report_md_content <- append(report_md_content, score_link, after = scoring_index)
389+
}
390+
}
367391

368392
# Write the updated content to report.md
369393
report_md_path <- path(reports_dir, "report.md")

reports/template.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
## Production Reports
66

77

8+
### Scoring this season
9+
10+
811
## Exploration Reports
912

1013
- [NHSN 2024-2025 Data Analysis](new_data.html)

scripts/covid_hosp_prod.R

Lines changed: 5 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ forecast_generation_dates <- Sys.Date()
1414
# Usually, the forecast_date is the same as the generation date, but you can
1515
# override this. It should be a Wednesday.
1616
forecast_dates <- round_date(forecast_generation_dates, "weeks", week_start = 3)
17+
18+
1719
# forecast_generation_date needs to follow suit, but it's more complicated
1820
# because sometimes we forecast on Thursday.
1921
# forecast_generation_dates <- c(as.Date(c("2024-11-20", "2024-11-27", "2024-12-04", "2024-12-11", "2024-12-18", "2024-12-26", "2025-01-02")), seq.Date(as.Date("2025-01-08"), Sys.Date(), by = 7L))
@@ -412,15 +414,7 @@ if (backtest_mode) {
412414
tar_target(
413415
external_forecasts,
414416
command = {
415-
locations_crosswalk <- get_population_data() %>%
416-
select(state_id, state_code) %>%
417-
filter(state_id != "usa")
418-
arrow::read_parquet("data/forecasts/covid_hosp_forecasts.parquet") %>%
419-
filter(output_type == "quantile") %>%
420-
select(forecaster, geo_value = location, forecast_date, target_end_date, quantile = output_type_id, value) %>%
421-
inner_join(locations_crosswalk, by = c("geo_value" = "state_code")) %>%
422-
mutate(geo_value = state_id) %>%
423-
select(forecaster, geo_value, forecast_date, target_end_date, quantile, value)
417+
get_external_forecasts("covid")
424418
}
425419
),
426420
tar_combine(
@@ -433,58 +427,13 @@ if (backtest_mode) {
433427
tar_target(
434428
name = scores,
435429
command = {
436-
truth_data <-
437-
nhsn_latest_data %>%
438-
select(geo_value, target_end_date = time_value, oracle_value = value) %>%
439-
left_join(
440-
get_population_data() %>%
441-
select(state_id, state_code),
442-
by = c("geo_value" = "state_id")
443-
) %>%
444-
drop_na() %>%
445-
rename(location = state_code) %>%
446-
select(-geo_value)
447-
forecasts_formatted <-
448-
joined_forecasts_and_ensembles %>%
449-
format_scoring_utils(disease = "covid")
450-
scores <- forecasts_formatted %>%
451-
filter(location != "US") %>%
452-
hubEvals::score_model_out(
453-
truth_data,
454-
metrics = c("wis", "ae_median", "interval_coverage_50", "interval_coverage_90"),
455-
summarize = FALSE,
456-
by = c("model_id")
457-
)
458-
scores %>%
459-
left_join(
460-
get_population_data() %>%
461-
select(state_id, state_code),
462-
by = c("location" = "state_code")
463-
) %>%
464-
rename(
465-
forecaster = model_id,
466-
forecast_date = reference_date,
467-
ahead = horizon,
468-
geo_value = state_id
469-
) %>%
470-
select(-location)
430+
score_forecasts(nhsn_latest_data, joined_forecasts_and_ensembles)
471431
}
472432
),
473433
tar_target(
474434
name = score_plot,
475435
command = {
476-
rmarkdown::render(
477-
score_report_rmd,
478-
params = list(
479-
scores = scores,
480-
forecast_dates = forecast_dates,
481-
disease = "covid"
482-
),
483-
output_file = here::here(
484-
"reports",
485-
sprintf("covid_backtesting_2024_2025_on_%s.html", as.Date(Sys.Date()))
486-
)
487-
)
436+
render_score_plot(score_report_rmd, scores, forecast_dates, "covid")
488437
},
489438
cue = tar_cue("always")
490439
)

scripts/flu_hosp_prod.R

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -467,15 +467,7 @@ if (backtest_mode) {
467467
tar_target(
468468
external_forecasts,
469469
command = {
470-
locations_crosswalk <- get_population_data() %>%
471-
select(state_id, state_code) %>%
472-
filter(state_id != "usa")
473-
arrow::read_parquet("data/forecasts/flu_hosp_forecasts.parquet") %>%
474-
filter(output_type == "quantile") %>%
475-
select(forecaster, geo_value = location, forecast_date, target_end_date, quantile = output_type_id, value) %>%
476-
inner_join(locations_crosswalk, by = c("geo_value" = "state_code")) %>%
477-
mutate(geo_value = state_id) %>%
478-
select(forecaster, geo_value, forecast_date, target_end_date, quantile, value)
470+
get_external_forecasts("flu")
479471
}
480472
),
481473
tar_combine(
@@ -488,31 +480,18 @@ if (backtest_mode) {
488480
tar_target(
489481
name = scores,
490482
command = {
491-
truth_data <- nhsn_latest_data %>%
492-
select(geo_value, target_end_date = time_value, true_value = value) %>%
493-
mutate(target_end_date = target_end_date + 3)
494-
joined_forecasts_and_ensembles %>%
495-
select(-source) %>%
496-
rename("model" = "forecaster", "prediction" = "value") %>%
497-
evaluate_predictions(forecasts = ., truth_data = truth_data) %>%
498-
rename("forecaster" = "model")
483+
nhsn_latest_end_of_week <-
484+
nhsn_latest_data %>%
485+
mutate(
486+
time_value = ceiling_date(time_value, unit = "week", week_start = 6)
487+
)
488+
score_forecasts(nhsn_latest_end_of_week, joined_forecasts_and_ensembles)
499489
}
500490
),
501491
tar_target(
502492
name = score_plot,
503493
command = {
504-
rmarkdown::render(
505-
score_report_rmd,
506-
params = list(
507-
scores = scores,
508-
forecast_dates = forecast_dates,
509-
disease = "flu"
510-
),
511-
output_file = here::here(
512-
"reports",
513-
sprintf("flu_backtesting_2024_2025_on_%s.html", as.Date(Sys.Date()))
514-
)
515-
)
494+
render_score_plot(score_report_rmd, scores, forecast_dates, "flu")
516495
}
517496
)
518497
)

0 commit comments

Comments
 (0)