diff --git a/.gitignore b/.gitignore
index 412a0b00..7eb03854 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,4 +23,4 @@ data/
 .nhsn_flu_cache.parquet
 meta/
 **/unnamed-chunk*
-decreasing_forecasters_cache/
\ No newline at end of file
+decreasing_forecasters_cache/
diff --git a/R/forecasters/formatters.R b/R/forecasters/formatters.R
index ce42c0bb..3b2669b4 100644
--- a/R/forecasters/formatters.R
+++ b/R/forecasters/formatters.R
@@ -58,16 +58,16 @@ format_covidhub <- function(pred, true_forecast_date, target_end_date, quantile_
 format_flusight <- function(pred, disease = c("flu", "covid")) {
   disease <- arg_match(disease)
   pred %>%
+    add_state_info(geo_value_col = "geo_value", old_geo_code = "state_id", new_geo_code = "state_code") %>%
     mutate(
       reference_date = get_forecast_reference_date(forecast_date),
       target = glue::glue("wk inc {disease} hosp"),
       horizon = as.integer(floor((target_end_date - reference_date) / 7)),
       output_type = "quantile",
       output_type_id = quantile,
-      value = value
+      value = value,
+      location = state_code
     ) %>%
-    left_join(get_population_data() %>% select(state_id, state_code), by = c("geo_value" = "state_id")) %>%
-    mutate(location = state_code) %>%
     select(reference_date, target, horizon, target_end_date, location, output_type, output_type_id, value)
 }
 
diff --git a/R/utils.R b/R/utils.R
index 45fe863c..bfc1518a 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -316,6 +316,20 @@ write_submission_file <- function(pred, forecast_reference_date, submission_dire
   readr::write_csv(pred, file_path)
 }
 
+#' The quantile levels used by the covidhub repository
+#'
+#' @param type either standard or inc_case, with inc_case being a small subset of the standard
+#' @return a numeric vector of quantile levels, rounded to 3 digits
+#'
+#' @export
+covidhub_probs <- function(type = c("standard", "inc_case")) {
+  type <- match.arg(type)
+  switch(type,
+    standard = c(0.01, 0.025, seq(0.05, 0.95, by = 0.05), 0.975, 0.99),
+    inc_case = c(0.025, 0.100, 0.250, 0.500, 0.750, 0.900, 0.975)
+  ) |> round(digits = 3)
+}
+
 #' Utility to get the reference date for a given date. This is the last day of
 #' the epiweek that the date falls in.
 get_forecast_reference_date <- function(date) {
diff --git a/_targets.yaml b/_targets.yaml
index 3f31a6e6..5f99959c 100644
--- a/_targets.yaml
+++ b/_targets.yaml
@@ -18,4 +18,7 @@ covid_hosp_prod:
   store: covid_hosp_prod
   use_crew: yes
   reporter_make: timestamp
-
+dashboard-proj:
+  script: scripts/dashboard-proj.R
+  store: dashboard-proj
+  use_crew: yes
diff --git a/get_forecast_data.r b/get_forecast_data.r
new file mode 100644
index 00000000..c9f2e271
--- /dev/null
+++ b/get_forecast_data.r
@@ -0,0 +1,144 @@
+library(tidyverse)
+library(httr)
+library(lubridate)
+library(progress)
+
+options(readr.show_progress = FALSE)
+options(readr.show_col_types = FALSE)
+
+
+# Configuration: remote hub location, forecasters to mirror, and local paths
+config <- list(
+  base_url = "https://raw.githubusercontent.com/cdcepi/FluSight-forecast-hub/main/model-output",
+  forecasters = c("CMU-TimeSeries", "FluSight-baseline", "FluSight-ensemble", "FluSight-base_seasonal", "UMass-flusion"),
+  local_storage = "data/forecasts",
+  tracking_file = "data/download_tracking.csv"
+)
+
+# Function to ensure directory structure exists: creates `base_dir` and one
+# sub-directory per forecaster in `config` (existing directories are kept)
+setup_directories <- function(base_dir) {
+  dir.create(file.path(base_dir), recursive = TRUE, showWarnings = FALSE)
+  for (forecaster in config$forecasters) {
+    dir.create(file.path(base_dir, forecaster), recursive = TRUE, showWarnings = FALSE)
+  }
+}
+
+# Function to load tracking data (an empty table if no tracking file exists).
+# Every column is read as character: new records store `download_date` as a
+# character timestamp, and letting readr guess a date-time type for the stored
+# column can make the later `bind_rows()` of old and new records fail.
+load_tracking_data <- function() {
+  if (file.exists(config$tracking_file)) {
+    read_csv(config$tracking_file, col_types = cols(.default = col_character()))
+  } else {
+    tibble(
+      forecaster = character(),
+      filename = character(),
+      download_date = character(),
+      status = character()
+    )
+  }
+}
+
+# Function to generate possible filenames for a date range: one
+# `<YYYY-MM-DD>-<forecaster>.csv` name per week from `start_date` to `end_date`
+generate_filenames <- function(start_date, end_date, forecaster) {
+  dates <- seq(as_date(start_date), as_date(end_date), by = "week")
+  paste0(format(dates, "%Y-%m-%d"), "-", forecaster, ".csv")
+}
+
+# Function to check if file exists on GitHub (TRUE when the request returns 200)
+check_github_file <- function(forecaster, filename) {
+  url <- paste0(config$base_url, "/", forecaster, "/", filename)
+  response <- GET(url)
+  status_code(response) == 200
+}
+
+# Function to download a single file into the forecaster's local directory.
+# Returns "success", or "failed" if `download.file()` signals an error.
+download_forecast_file <- function(forecaster, filename) {
+  url <- paste0(config$base_url, "/", forecaster, "/", filename)
+  local_path <- file.path(config$local_storage, forecaster, filename)
+
+  tryCatch(
+    {
+      download.file(url, local_path, mode = "wb", quiet = TRUE)
+      "success"
+    },
+    error = function(e) {
+      "failed"
+    }
+  )
+}
+
+# Main function to update forecast files: downloads the files of the last
+# `days_back` days that are not yet tracked as "success", appends one tracking
+# record per attempted download, and returns the tracking table.
+# NOTE(review): `get_forecast_reference_date()` is defined in R/utils.R, not in
+# this script -- presumably it has to be loaded before calling this; confirm.
+update_forecast_files <- function(days_back = 30) {
+  # Setup
+  setup_directories(config$local_storage)
+  tracking_data <- load_tracking_data()
+
+  # Generate date range
+  end_date <- Sys.Date()
+  start_date <- get_forecast_reference_date(end_date - days_back)
+
+  # Process each forecaster
+  new_tracking_records <- list()
+
+  pb_forecasters <- progress_bar$new(
+    format = "Downloading forecasts from :forecaster [:bar] :percent :eta",
+    total = length(config$forecasters),
+    clear = FALSE,
+    width = 60
+  )
+
+  for (forecaster in config$forecasters) {
+    pb_forecasters$tick(tokens = list(forecaster = forecaster))
+
+    # Get potential filenames
+    filenames <- generate_filenames(start_date, end_date, forecaster)
+
+    # Filter out already downloaded files
+    existing_files <- tracking_data %>%
+      filter(forecaster == !!forecaster, status == "success") %>%
+      pull(filename)
+
+    new_files <- setdiff(filenames, existing_files)
+
+    if (length(new_files) > 0) {
+      # Create nested progress bar for files
+      pb_files <- progress_bar$new(
+        format = " Downloading files [:bar] :current/:total :filename",
+        total = length(new_files)
+      )
+
+      for (filename in new_files) {
+        pb_files$tick(tokens = list(filename = filename))
+
+        if (check_github_file(forecaster, filename)) {
+          status <- download_forecast_file(forecaster, filename)
+
+          new_tracking_records[[length(new_tracking_records) + 1]] <- tibble(
+            forecaster = forecaster,
+            filename = filename,
+            download_date = as.character(Sys.time()),
+            status = status
+          )
+        }
+      }
+    }
+  }
+
+  # Update tracking data
+  if (length(new_tracking_records) > 0) {
+    new_tracking_data <- bind_rows(new_tracking_records)
+    tracking_data <- bind_rows(tracking_data, new_tracking_data)
+    write_csv(tracking_data, config$tracking_file)
+  }
+
+  tracking_data
+}
diff --git a/scripts/dashboard-proj.R b/scripts/dashboard-proj.R
new file mode 100644
index 00000000..e69de29b
diff --git a/scripts/reports/forecast_dashboard.Rmd b/scripts/reports/forecast_dashboard.Rmd
new file mode 100644
index 00000000..ff921854
--- /dev/null
+++ b/scripts/reports/forecast_dashboard.Rmd
@@ -0,0 +1,330 @@
+---
+title: "Disease Surveillance Dashboard"
+output:
+  flexdashboard::flex_dashboard:
+    orientation: columns
+    vertical_layout: fill
+runtime: shiny
+---
+
+```{r setup, include=FALSE}
+library(tidyverse)
+library(httr)
+library(lubridate)
+library(progress)
+library(targets)
+source(here::here("R", "load_all.R"))
+
+options(readr.show_progress = FALSE)
+options(readr.show_col_types = FALSE)
+insufficient_data_geos <- c("as", "mp", "vi", "gu")
+
+
+config <- list(
+  base_url = "https://raw.githubusercontent.com/cdcepi/FluSight-forecast-hub/main/model-output",
+  forecasters = c(
+    "CMU-TimeSeries", "FluSight-baseline", "FluSight-ensemble",
+    "FluSight-base_seasonal", "UMass-flusion"
+  )
+)
+
+# Function to fetch NHSN data: returns the flu series as an `epi_df`, keeping
+# only the most recent `version`, dropping the geos in `insufficient_data_geos`
+# and relabelling "usa" as "us"
+get_nhsn_data <- function() {
+  if (wday(Sys.Date()) < 6 && wday(Sys.Date()) > 3) {
+    # download from the preliminary data source from Wednesday to Friday
+    # NOTE(review): with lubridate's default week start (Sunday = 1) this is
+    # TRUE on Wednesday and Thursday only -- confirm Friday is meant to be excluded
+    most_recent_result <- readr::read_csv("https://data.cdc.gov/resource/mpgq-jmmr.csv?$limit=20000&$select=weekendingdate,jurisdiction,totalconfc19newadm,totalconfflunewadm")
+  } else {
+    most_recent_result <- readr::read_csv("https://data.cdc.gov/resource/ua7e-t2fy.csv?$limit=20000&$select=weekendingdate,jurisdiction,totalconfc19newadm,totalconfflunewadm")
+  }
+  most_recent_result %>%
+    process_nhsn_data() %>%
+    filter(disease == "nhsn_flu") %>%
+    select(-disease) %>%
+    filter(geo_value %nin% insufficient_data_geos) %>%
+    mutate(
+      source = "nhsn",
+      geo_value = ifelse(geo_value == "usa", "us", geo_value),
+      time_value = time_value
+    ) %>%
+    filter(version == max(version)) %>%
+    select(-version) %>%
+    data_substitutions(disease = "flu") %>%
+    as_epi_df(other_keys = "source", as_of = Sys.Date())
+}
+
+# Function to fetch forecasts: for every forecaster in `config` and every weekly
+# date in the last `days_back` days, read `<date>-<forecaster>.csv` from the hub.
+# Files that do not answer with HTTP 200, or whose download/parsing errors, are
+# skipped silently.
+get_forecasts <- function(days_back = 120) {
+  end_date <- Sys.Date()
+  start_date <- end_date - days_back
+  # Same reference-date helper (R/utils.R) that the download script uses
+  dates <- seq(get_forecast_reference_date(start_date), end_date, by = "week")
+
+  all_forecasts <- map(config$forecasters, function(forecaster) {
+    map(dates, function(date) {
+      filename <- paste0(format(date, "%Y-%m-%d"), "-", forecaster, ".csv")
+      url <- paste0(config$base_url, "/", forecaster, "/", filename)
+
+      tryCatch(
+        {
+          response <- GET(url)
+          if (status_code(response) == 200) {
+            read_csv(url, col_types = list(
+              reference_date = col_date(format = "%Y-%m-%d"),
+              target_end_date = col_date(format = "%Y-%m-%d"),
+              target = col_character(),
+              location = col_character(),
+              horizon = col_integer(),
+              output_type = col_character(),
+              output_type_id = col_character(),
+              value = col_double(),
+              forecaster = col_character(),
+              forecast_date = col_date(format = "%Y-%m-%d")
+            )) %>%
+              mutate(
+                forecaster = forecaster,
+                forecast_date = as.Date(date)
+              )
+          }
+        },
+        error = function(e) NULL
+      )
+    }) %>%
+      bind_rows()
+  }) %>%
+    bind_rows() %>%
+    add_state_info(geo_value_col = "location", old_geo_code = "state_code", new_geo_code = "state_id")
+
+  return(all_forecasts)
+}
+
+# Score the quantile forecasts against the observed NHSN values: both inputs
+# are reshaped to the columns passed to `evaluate_predictions()`, and its result
+# is returned with `model` renamed back to `forecaster`
+score_forecasts <- function(all_forecasts, nhsn_latest_data) {
+  predictions_cards <- all_forecasts %>%
+    rename(model = forecaster) %>%
+    mutate(
+      quantile = as.numeric(output_type_id),
+      prediction = value
+    ) %>%
+    select(model, geo_value, forecast_date, target_end_date, quantile, prediction)
+
+  truth_data <- nhsn_latest_data %>%
+    mutate(
+      target_end_date = as.Date(time_value),
+      true_value = value
+    ) %>%
+    select(geo_value, target_end_date, true_value)
+
+  evaluate_predictions(predictions_cards = predictions_cards, truth_data = truth_data) %>%
+    rename(forecaster = model)
+}
+
+# Fetch all data
+nhsn_latest_data <- get_nhsn_data()
+
+# Create NHSN archive
+nhsn_archive_data <- create_nhsn_data_archive(disease = "nhsn_flu")
+
+# Fetch forecasts
+all_forecasts_unfiltered <- get_forecasts(days_back = 120)
+all_forecasts <- all_forecasts_unfiltered %>%
+  filter(
+    target == "wk inc flu hosp",
+    output_type == "quantile"
+  ) %>%
+  mutate(
+    geo_value = state_id
+  ) %>%
+  select(-location)
+
+# Score forecasts
+all_scores <- score_forecasts(all_forecasts, nhsn_latest_data)
+```
+
+```{r dashboard-data, include=FALSE}
+library(flexdashboard)
+library(tidyverse)
+library(plotly)
+library(shiny)
+library(DT)
+library(epiprocess)
+
+# Setup the data
+nhsn_data <- nhsn_latest_data
+nhsn_historical <- nhsn_archive_data
+forecast_data <- all_forecasts %>%
+  filter(output_type_id %in% c("0.05", "0.5", "0.95")) %>%
+  pivot_wider(names_from = output_type_id, values_from = value)
+scores_data <- all_scores
+
+# Pre-compute the historical data as-of each forecast date
+unique_forecast_dates <- sort(unique(forecast_data$forecast_date))
+nhsn_historical_as_ofs <- map(unique_forecast_dates, ~ {
+  nhsn_historical %>%
+    epix_as_of(as.Date(.x)) %>%
+    mutate(time_value = time_value)
+})
+```
+
+# Column {.sidebar data-width=300}
+
+```{r}
+# Location selector
+selectInput("location", "Select Location",
+  choices = unique(nhsn_data$geo_value),
+  selected = "us"
+)
+
+# Date range for historical data
+dateRangeInput("date_range", "Select Date Range",
+  start = as.Date("2023-07-01"),
+  end = max(nhsn_data$time_value),
+  max = max(nhsn_data$time_value)
+)
+
+# Add checkboxes for forecaster selection
+checkboxGroupInput("forecasters", "Select Forecasters:",
+  choices = unique(forecast_data$forecaster),
+  selected = c("FluSight-ensemble", "CMU-TimeSeries")
+)
+
+div(
+  style = "text-align: center;",
+  p("Forecast Date"),
+  div(
+    style = "display: inline-block;",
+    actionButton("prev_date", "←"),
+    actionButton("next_date", "→")
+  ),
+  textOutput("current_forecast_date")
+)
+
+current_date_index <- reactiveVal(length(unique_forecast_dates))
+
+observeEvent(input$next_date, {
+  current_date_index(min(length(unique_forecast_dates), current_date_index() + 1))
+})
+
+observeEvent(input$prev_date, {
+  current_date_index(max(1, current_date_index() - 1))
+})
+
+output$current_forecast_date <- renderText({
+  # format() first: renderText() hands its value to cat(), which would print
+  # a Date as its underlying number
+  format(unique_forecast_dates[current_date_index()])
+})
+
+# Debounce to avoid re-computing the plot too often
+plot_inputs <- reactive({
+  list(
+    location = input$location,
+    date_range = input$date_range,
+    forecasters = input$forecasters,
+    forecast_date = unique_forecast_dates[current_date_index()]
+  )
+}) %>% debounce(500)
+
+# Everything below reads the debounced snapshot `inputs`, which is also the
+# cache key, so the forecasts, the as-of data and the cached plot always
+# describe the same selection
+plot_result <- reactive({
+  req(plot_inputs())
+  inputs <- plot_inputs()
+
+  data_subset <- forecast_data %>%
+    filter(
+      forecaster %in% inputs$forecasters,
+      forecast_date == inputs$forecast_date,
+      geo_value == inputs$location
+    )
+
+  # Base data (light grey)
+  p <- ggplot() +
+    geom_line(
+      data = nhsn_data %>%
+        filter(
+          geo_value == inputs$location,
+          inputs$date_range[1] <= time_value,
+          time_value <= inputs$date_range[2]
+        ),
+      aes(x = time_value, y = value),
+      color = "grey80"
+    ) +
+
+    # As-of data (black)
+    geom_line(
+      data = nhsn_historical_as_ofs[[match(inputs$forecast_date, unique_forecast_dates)]] %>%
+        filter(
+          geo_value == inputs$location,
+          inputs$date_range[1] <= time_value,
+          time_value <= inputs$date_range[2]
+        ),
+      aes(x = time_value, y = value),
+      color = "black"
+    ) +
+
+    # Median line
+    geom_line(
+      data = data_subset,
+      aes(x = target_end_date, y = .data$`0.5`, color = forecaster)
+    ) +
+
+    # Interval between the 0.05 and 0.95 forecast quantiles
+    geom_ribbon(
+      data = data_subset,
+      aes(
+        x = target_end_date,
+        ymin = .data$`0.05`,
+        ymax = .data$`0.95`,
+        fill = forecaster
+      ),
+      alpha = 0.5
+    ) +
+    theme_minimal() +
+    labs(
+      x = "Date", y = "Hospitalizations",
+      title = paste("Disease Surveillance for", toupper(inputs$location))
+    )
+
+  ggplotly(p)
+}) %>% bindCache(plot_inputs())
+```
+
+# Forecast Fan Plots {data-width=750}
+
+```{r}
+renderPlotly({
+  plot_result()
+})
+```
+
+# Performance Metrics {data-width=350}
+
+```{r}
+renderDT({
+  # Calculate performance metrics per forecaster for the selected location
+  performance_metrics <- scores_data %>%
+    filter(geo_value == input$location) %>%
+    group_by(ahead, forecaster) %>%
+    summarise(
+      wis = round(mean(wis), 2),
+      ae = round(mean(ae), 2),
+      coverage_80 = round(mean(coverage_80), 2)
+    ) %>%
+    ungroup() %>%
+    arrange(ahead, wis)
+
+  datatable(performance_metrics,
+    options = list(pageLength = 50, scrollY = TRUE)
+  )
+})
+```
diff --git a/scripts/serve-dashboard.r b/scripts/serve-dashboard.r
new file mode 100644
index 00000000..769ca5ab
--- /dev/null
+++ b/scripts/serve-dashboard.r
@@ -0,0 +1,27 @@
+# Run the dashboard
+library(targets)
+library(here)
+source(here::here("R", "load_all.R"))
+
+Sys.setenv(TAR_PROJECT = "dashboard-proj")
+
+tar_make()
+# NOTE(review): the dashboard Rmd declares no `params` in its YAML header and
+# recomputes these objects in its own setup chunk -- confirm that passing them
+# here is intended and accepted by `rmarkdown::run()`
+rmarkdown::run(
+  file = here::here("scripts", "reports", "forecast_dashboard.Rmd"),
+  render_args = list(
+    params = list(
+      nhsn_latest_data = tar_read(nhsn_latest_data),
+      nhsn_archive_data = tar_read(nhsn_archive_data),
+      all_forecasts = tar_read(all_forecasts),
+      all_scores = tar_read(all_scores)
+    )
+  )
+)
+
+# Once ready, do this
+# NOTE(review): when the script is sourced, this line runs as soon as the call
+# above returns -- confirm whether it is meant to be run by hand instead
+rsconnect::deployDoc(doc = here::here("scripts", "reports", "forecast_dashboard.Rmd"))