Skip to content

Commit 6d1ca2a

Browse files
authored
Merge pull request #51 from cmu-delphi/ensemble
ensemble models
2 parents 1c55ae2 + 0bd0b55 commit 6d1ca2a

29 files changed

+706
-60
lines changed

.github/workflows/R-CMD-check.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
22
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
33
on:
4+
workflow_dispatch:
45
push:
56
branches: [main]
67
pull_request:

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,8 @@
44
tmp/
55
extras/**.html
66
*.pdf
7-
.Renviron
7+
.Renviron
8+
.renvignore
9+
nohup.out
10+
run.Rout
11+
tmp.R

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epieval
22
Title: Evaluating Timeseries Forecasting on Archival Data
3-
Version: 0.1.0
3+
Version: 0.2.0
44
Date: 2023-09-28
55
Authors@R:
66
c(

NAMESPACE

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,25 @@ export(clear_lastminute_nas)
88
export(collapse_cards)
99
export(confirm_sufficient_data)
1010
export(covidhub_probs)
11+
export(ensemble_average)
12+
export(ensemble_missing_forecasters)
13+
export(ensemble_missing_forecasters_details)
1114
export(evaluate_predictions)
1215
export(extend_ahead)
1316
export(flatline_fc)
17+
export(forecaster_lookup)
1418
export(forecaster_pred)
1519
export(format_storage)
1620
export(id_ahead_ensemble_grid)
1721
export(interval_coverage)
18-
export(lookup_ids)
1922
export(make_data_targets)
20-
export(make_ensemble_targets)
23+
export(make_ensemble_targets_and_scores)
2124
export(make_external_names_and_scores)
2225
export(make_forecasts_and_scores)
2326
export(make_forecasts_and_scores_by_ahead)
27+
export(make_shared_ensembles)
2428
export(make_shared_grids)
29+
export(make_target_ensemble_grid)
2530
export(make_target_param_grid)
2631
export(manage_S3_forecast_cache)
2732
export(overprediction)
@@ -83,16 +88,21 @@ importFrom(epiprocess,epix_slide)
8388
importFrom(here,here)
8489
importFrom(magrittr,"%<>%")
8590
importFrom(magrittr,"%>%")
91+
importFrom(purrr,imap)
8692
importFrom(purrr,map)
8793
importFrom(purrr,map2_vec)
94+
importFrom(purrr,map_vec)
8895
importFrom(purrr,transpose)
8996
importFrom(recipes,all_numeric)
9097
importFrom(rlang,"!!")
98+
importFrom(rlang,"%||%")
9199
importFrom(rlang,.data)
92100
importFrom(rlang,quo)
93101
importFrom(rlang,sym)
94102
importFrom(rlang,syms)
103+
importFrom(targets,tar_config_get)
95104
importFrom(targets,tar_group)
105+
importFrom(targets,tar_read)
96106
importFrom(targets,tar_target)
97107
importFrom(tibble,tibble)
98108
importFrom(tidyr,drop_na)

R/ensemble_average.R

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#' an ensemble model that averages each quantile separately
2+
#' @description
3+
#' The simplest class of ensembling models, it takes in a list of quantile
4+
#' forecasts and averages them on a per-quantile basis. By default the average
5+
#' used is the median, but it can accept any vectorized function.
6+
#' @param epi_data unused for this forecaster, but potentially an ensemble may
7+
#' want the underlying data.
8+
#' @param outcome The name of the target variable.
9+
#' @param extra_sources The name of any extra columns to use. This list could be
10+
#' empty
11+
#' @param forecasts a list of quantile forecasts to aggregate. They should
12+
#' be tibbles with columns `(geo_value, forecast_date, target_end_date,
13+
#' quantile, value)`, preferably in that order.
14+
#' @param ensemble_args any arguments unique to this particular ensembler should
15+
#' be included in a list like this (unfortunate targets issues). The arguments
16+
#' for `ensemble_average` in particular are `average_type` and `join_columns`
17+
#' @param ensemble_args_names an argument purely for use in targets. You
18+
#' probably shouldn't worry about it. In a target, it should probably be
19+
#' `ensemble_args_names = names(ensemble_args)`
20+
#' @importFrom rlang %||%
21+
#' @export
22+
ensemble_average <- function(epi_data,
23+
forecasts,
24+
outcome,
25+
extra_sources = "",
26+
ensemble_args = list(),
27+
ensemble_args_names = NULL) {
28+
# unique parameters must be buried in ensemble_args so that the generic function signature is stable
29+
# their names are separated for obscure target related reasons
30+
if (!is.null(ensemble_args_names)) {
31+
names(ensemble_args) <- ensemble_args_names
32+
}
33+
average_type <- ensemble_args$average_type %||% median
34+
join_columns <- ensemble_args$join_columns %||% c("geo_value", "forecast_date", "target_end_date", "quantile")
35+
# begin actual analysis
36+
bind_rows(!!!forecasts, .id = "forecaster") %>%
37+
group_by(across(all_of(join_columns))) %>%
38+
summarize(value = average_type(value)) %>%
39+
ungroup()
40+
}

R/epieval-package.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#' @importFrom magrittr %>% %<>%
22
#' @importFrom dplyr select rename inner_join join_by mutate relocate any_of
33
#' group_by reframe summarize left_join across filter rowwise everything ungroup
4-
#' @importFrom purrr transpose map map2_vec
4+
#' @importFrom purrr transpose map map2_vec map_vec imap
55
#' @keywords internal
66
"_PACKAGE"
77
globalVariables(c("ahead", "id", "parent_id", "all_of", "last_col", "time_value", "geo_value", "target_end_date", "forecast_date", "quantile", ".pred_distn", "quantiles", "quantile_levels", "signal", ".dstn", "values", ".", "forecasters", "forecaster", "trainer", "forecast_date", ".pred", "n_distinct", "target_date", "value"))

R/targets_utils.R

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
#' @export
77
#' @importFrom rlang syms
88
make_target_param_grid <- function(param_grid) {
9+
not_na <- !is.na(param_grid$trainer)
10+
param_grid$trainer[not_na] <- syms(param_grid$trainer[not_na])
911
param_grid %<>%
1012
select(-any_of("parent_id")) %>%
11-
mutate(forecaster = syms(forecaster)) %>%
12-
mutate(trainer = syms(trainer))
13+
mutate(forecaster = syms(forecaster))
1314
list_of_params <- lists_of_real_values(param_grid)
1415
list_names <- map(list_of_params, names)
1516
tibble(
@@ -19,6 +20,30 @@ make_target_param_grid <- function(param_grid) {
1920
param_names = list_names
2021
)
2122
}
23+
#' convert a tibble of ensembles into the format targets needs
24+
#' @description
25+
#' the required format for targets is a little jank; this takes a human legible tibble and makes it targets legible.
26+
#' Currently `ensemble`, the `average_type` entry of `ensemble_params`, and `forecaster_ids` are converted to symbols.
27+
#' @param param_grid the tibble of ensembles. Must have the columns `id`, `ensemble`, `ensemble_params`, `forecasters`, and `forecaster_ids`
28+
#' @param ONE_AHEAD_FORECASTER_NAME the prefix that is pasted (separated by `_`) in front of every entry of `forecaster_ids` before it is turned into a symbol
29+
#' @export
30+
#' @importFrom rlang syms
31+
make_target_ensemble_grid <- function(param_grid, ONE_AHEAD_FORECASTER_NAME = "forecast_by_ahead") {
32+
param_grid$ensemble_params <- map(param_grid$ensemble_params, sym_subset)
33+
param_grid %<>%
34+
mutate(ensemble = syms(ensemble)) %>%
35+
mutate(ensemble_params_names = list(names(ensemble_params))) %>%
36+
select(-forecasters) %>%
37+
relocate(id, .before = everything()) %>%
38+
mutate(forecaster_ids = list(syms(paste(ONE_AHEAD_FORECASTER_NAME, forecaster_ids, sep = "_"))))
39+
return(param_grid)
40+
}
41+
#' turn the entries of a parameter list named in `sym_names` into symbols; meant to be mapped over `ensemble_params`
42+
#' @keywords internal
43+
#' @param sym_names a list of the parameter names that should be turned into symbols
44+
sym_subset <- function(param_list, sym_names = list("average_type")) {
45+
imap(param_list, \(x, y) if (y %in% sym_names) sym(x) else x)
46+
}
2247

2348
#' helper function for `make_target_param_grid`
2449
#' @keywords internal
@@ -150,7 +175,7 @@ make_data_targets <- function() {
150175
)
151176
}
152177

153-
#' Make common targets for forecasting experiments
178+
#' Make list of common forecasters for forecasting experiments across projects
154179
#' @export
155180
make_shared_grids <- function() {
156181
list(
@@ -163,12 +188,44 @@ make_shared_grids <- function() {
163188
tidyr::expand_grid(
164189
forecaster = "scaled_pop",
165190
trainer = c("linreg", "quantreg"),
166-
ahead = 5:7,
191+
ahead = 1:7,
167192
lags = list(c(0, 3, 5, 7, 14), c(0, 7, 14)),
168193
pop_scaling = c(FALSE)
194+
),
195+
tidyr::expand_grid(
196+
forecaster = "flatline_fc",
197+
ahead = 1:7
169198
)
170199
)
171200
}
201+
#' Make list of common ensembles for forecasting experiments across projects
202+
#' @export
203+
make_shared_ensembles <- function() {
204+
ex_forecaster <- list(
205+
forecaster = "scaled_pop",
206+
trainer = "linreg",
207+
pop_scaling = FALSE,
208+
lags = c(0, 3, 5, 7, 14)
209+
)
210+
# ensembles don't lend themselves to expand grid (inherently needs a list for sub-forecasters)
211+
tribble(
212+
~ensemble, ~ensemble_params, ~forecasters,
213+
# mean forecaster
214+
"ensemble_average",
215+
list(average_type = "mean"),
216+
list(
217+
ex_forecaster,
218+
list(forecaster = "flatline_fc")
219+
),
220+
# median forecaster
221+
"ensemble_average",
222+
list(average_type = "median"),
223+
list(
224+
ex_forecaster,
225+
list(forecaster = "flatline_fc")
226+
),
227+
)
228+
}
172229

173230
#' Make forecasts and scores by ahead targets
174231
#' @description
@@ -238,8 +295,25 @@ make_forecasts_and_scores <- function() {
238295

239296
#' Make ensemble targets
240297
#' @export
241-
make_ensemble_targets <- function() {
242-
list()
298+
make_ensemble_targets_and_scores <- function() {
299+
ensembles_and_scores <- tar_map(
300+
values = ensemble_parent_id_map,
301+
names = parent_id,
302+
tar_target(
303+
name = ensemble,
304+
command = {
305+
bind_rows(ensemble_component_ids) %>%
306+
mutate(parent_ensemble = parent_id)
307+
}
308+
),
309+
tar_target(
310+
name = ensemble_score,
311+
command = {
312+
bind_rows(score_component_ids) %>%
313+
mutate(parent_ensemble = parent_id)
314+
}
315+
)
316+
)
243317
}
244318

245319

R/utils.R

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,75 @@ add_id <- function(df, n_adj = 2) {
3636
return(df)
3737
}
3838

39+
#' look up forecasters by name
40+
#' @description
41+
#' given a (partial) forecaster name, look up all forecasters in the given project which contain part of that name.
42+
#' @param forecaster_name a part of the adj.adj.1 name used to identify the forecaster.
43+
#' @param param_grid the tibble of forecasters whose `id` column is searched; if `NULL`, `forecaster_params_grid` is read from the project store
44+
#' @param project the project store to be used; by default, the value of `tar_config_get("store")` is used
45+
#' @importFrom targets tar_read tar_config_get
46+
#' @export
47+
forecaster_lookup <- function(forecaster_name, param_grid = NULL, project = NULL) {
48+
forecaster_name <- strip_underscored(forecaster_name)
49+
if (is.null(project)) {
50+
project <- tar_config_get("store")
51+
}
52+
if (is.null(param_grid)) {
53+
param_grid <- tar_read(forecaster_params_grid, store = project)
54+
}
55+
param_grid %>% filter(grepl(forecaster_name, id))
56+
}
57+
58+
strip_underscored <- function(x) {
59+
g <- gregexpr("_", x, fixed = TRUE)
60+
last_underscore <- g[[1]][[length(g[[1]])]]
61+
substr(x[[1]], start = last_underscore + 1, stop = nchar(x))
62+
}
63+
64+
#' list forecasters used in the given ensemble table not found in the given forecaster grid
65+
#' @description
66+
#' list forecasters used in the given ensemble table not found in the given forecaster grid
67+
#'
68+
#' @param ensemble_grid the grid of ensembles used
69+
#' @param param_grid the grid of forecasters used that we're checking for presence
70+
#' @param project the project store to be used; by default, the value of `tar_config_get("store")` is used
71+
#' @export
72+
ensemble_missing_forecasters <- function(ensemble_grid = NULL, param_grid = NULL, project = NULL) {
73+
if (is.null(project)) {
74+
project <- tar_config_get("store")
75+
}
76+
if (is.null(ensemble_grid)) {
77+
ensemble_grid <- tar_read(ensemble_forecasters, store = project)
78+
}
79+
used_forecasters <- unlist(ensemble_grid$forecaster_ids) %>% unique()
80+
is_present <- map_vec(used_forecasters, \(given_forecaster) nrow(forecaster_lookup(given_forecaster, param_grid, project)) > 0)
81+
absent_forecasters <- used_forecasters[!is_present]
82+
return(absent_forecasters)
83+
}
84+
85+
#' for the forecasters used in the ensemble grid but not found in the forecaster grid, return their parameters together with their ids
86+
#' @inheritParams ensemble_missing_forecasters
87+
#' @export
88+
ensemble_missing_forecasters_details <- function(ensemble_grid = NULL, param_grid = NULL, project = NULL) {
89+
absent_forecasters <- ensemble_missing_forecasters(ensemble_grid, param_grid, project)
90+
grid_with_missing <- ensemble_grid %>%
91+
rowwise() %>%
92+
mutate(
93+
missing_forecasters = list(map(
94+
absent_forecasters,
95+
# extract a list of the subforecasters with associated id, with only the missing ones having non-empty lists
96+
function(absent_fc) {
97+
is_missing <- grepl(absent_fc, forecaster_ids)
98+
params_only <- forecasters[is_missing]
99+
mapply(c, params_only, id = forecaster_ids[is_missing])
100+
}
101+
))
102+
)
103+
flat_missing <- unlist(grid_with_missing$missing_forecasters, recursive = FALSE)
104+
unique_missing <- flat_missing[map_vec(flat_missing, \(x) length(x) > 0)] %>% unique()
105+
return(unique_missing)
106+
}
107+
39108

40109
#' generate an id from a simple list of parameters
41110
#' @param param_list the list of parameters. must include `ahead` if `ahead = NULL`
@@ -45,6 +114,7 @@ add_id <- function(df, n_adj = 2) {
45114
single_id <- function(param_list, ahead = NULL, n_adj = 2) {
46115
full_hash <- param_list[names(param_list) != "ahead"] %>%
47116
.[order(names(.))] %>% # put in alphabetical order
117+
lapply(function(x) if (length(x) > 1) list(x) else x) %>% # the tibble version needs vectors to actually be lists, so this is a conversion to make sure the strings are identical
48118
paste(collapse = "") %>%
49119
hash_animal(n_adj = n_adj)
50120
single_string <- full_hash$words[[1]][1:n_adj] %>% paste(sep = ".", collapse = ".")
@@ -56,16 +126,11 @@ single_id <- function(param_list, ahead = NULL, n_adj = 2) {
56126
return(full_name)
57127
}
58128

59-
60-
#' given target name(s), lookup the corresponding parameters
61-
#' @export
62-
lookup_ids <- function() {
63-
}
64-
65-
66129
#' add aheads, forecaster_ids, and ids to a list of ensemble models
67130
#' @description
68-
#' minor utility
131+
#' First, do an expand grid to do a full combination of ensemble_grid x aheads.
132+
#' Then add a column containing lists of ids of the dependent forecasters
133+
#' based on their parameters.
69134
#' @param ensemble_grid the list of ensembles,
70135
#' @param aheads the aheads to add
71136
#' @inheritParams add_id
@@ -82,6 +147,9 @@ id_ahead_ensemble_grid <- function(ensemble_grid, aheads, n_adj = 2) {
82147
add_id(., n_adj = 2) %>%
83148
rowwise() %>%
84149
mutate(forecaster_ids = list(map2_vec(forecasters, ahead, single_id, n_adj = 2)))
150+
if (length(ensemble_grid$id %>% unique()) < length(ensemble_grid$id)) {
151+
abort("ensemble grid has non-unique forecasters")
152+
}
85153
return(ensemble_grid)
86154
}
87155

app.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,11 @@ shinyApp(
136136
server = function(input, output, session) {
137137
filtered_scorecards_reactive <- reactive({
138138
agg_forecasters <- unique(c(input$selected_forecasters, input$baseline))
139-
if (length(agg_forecasters) == 0) { return(data.frame()) }
139+
if (length(agg_forecasters) == 0 ||
140+
all(agg_forecasters == "" | is.null(agg_forecasters) | is.na(agg_forecasters))
141+
) {
142+
return(data.frame())
143+
}
140144

141145
processed_evaluations_internal <- lapply(agg_forecasters, function(forecaster) {
142146
load_forecast_data(forecaster) %>>%

0 commit comments

Comments
 (0)