diff --git a/R/step_training_window.R b/R/step_training_window.R index eafc076c..812797d7 100644 --- a/R/step_training_window.R +++ b/R/step_training_window.R @@ -8,6 +8,17 @@ #' @param n_recent An integer value that represents the number of most recent #' observations that are to be kept in the training window per key #' The default value is 50. +#' @param seasonal Bool, default FALSE. If TRUE, the training window will slice +#' through epidemic seasons. This is useful for forecasting models that need +#' to leverage the data in previous years, but only limited to similar phases +#' in the epidemic. Most useful to heavily seasonal data, like influenza. +#' Expects n_recent to be finite. +#' @param seasonal_forward_window An integer value that represents the number of days +#' after a season week to include in the training window. The default value +#' is 14. Only valid when seasonal is TRUE. +#' @param seasonal_backward_window An integer value that represents the number of days +#' before a season week to include in the training window. The default value +#' is 35. Only valid when seasonal is TRUE. #' @param epi_keys An optional character vector for specifying "key" variables #' to group on. The default, `NULL`, ensures that every key combination is #' limited. @@ -42,10 +53,13 @@ step_training_window <- function(recipe, role = NA, n_recent = 50, + seasonal = FALSE, + seasonal_forward_window = 14, + seasonal_backward_window = 35, epi_keys = NULL, id = rand_id("training_window")) { - arg_is_scalar(n_recent, id) - arg_is_pos(n_recent) + arg_is_scalar(n_recent, id, seasonal, seasonal_forward_window, seasonal_backward_window) + arg_is_pos(n_recent, seasonal_forward_window, seasonal_backward_window) if (is.finite(n_recent)) arg_is_pos_int(n_recent) arg_is_chr(id) arg_is_chr(epi_keys, allow_null = TRUE) @@ -55,6 +69,9 @@ step_training_window <- role = role, trained = FALSE, n_recent = n_recent, + seasonal = seasonal, + seasonal_forward_window = seasonal_forward_window, + seasonal_backward_window = seasonal_backward_window, epi_keys = epi_keys, skip = TRUE, id = id @@ -63,12 +80,15 @@ step_training_window <- } step_training_window_new <- - function(role, trained, n_recent, epi_keys, skip, id) { + function(role, trained, n_recent, seasonal, seasonal_forward_window, seasonal_backward_window, epi_keys, skip, id) { step( subclass = "training_window", role = role, trained = trained, n_recent = n_recent, + seasonal = seasonal, + seasonal_forward_window = seasonal_forward_window, + seasonal_backward_window = seasonal_backward_window, epi_keys = epi_keys, skip = skip, id = id @@ -86,6 +106,9 @@ prep.step_training_window <- function(x, training, info = NULL, ...) { role = x$role, trained = TRUE, n_recent = x$n_recent, + seasonal = x$seasonal, + seasonal_forward_window = x$seasonal_forward_window, + seasonal_backward_window = x$seasonal_backward_window, epi_keys = ek, skip = x$skip, id = x$id @@ -104,15 +127,49 @@ bake.step_training_window <- function(object, new_data, ...) { ungroup() } + if (object$seasonal) { + # TODO: This needs to take into account different time types of time_value. + # Currently, it assumes time_value is a Date. + new_data <- new_data %>% add_season_info() + last_data_season_week <- new_data %>% + filter(time_value == max(time_value)) %>% + pull(season_week) %>% + max() + recent_weeks <- c(last_data_season_week) + if (inherits(new_data, "epi_df")) { + current_season_week <- convert_epiweek_to_season_week(epiyear(epi_as_of(new_data)), epiweek(epi_as_of(new_data))) + recent_weeks <- c(recent_weeks, current_season_week) + } + date_ranges <- new_data %>% + filter(season_week %in% recent_weeks) %>% + pull(time_value) %>% + unique() %>% + map(~ c(.x - 1:(object$seasonal_backward_window), .x + 0:(object$seasonal_forward_window))) %>% + unlist() %>% + as.Date() %>% + unique() + new_data %<>% filter(time_value %in% date_ranges) + } + + new_data } #' @export print.step_training_window <- function(x, width = max(20, options()$width - 30), ...) { - title <- "# of recent observations per key limited to:" - n_recent <- x$n_recent - tr_obj <- recipes::format_selectors(rlang::enquos(n_recent), width) - recipes::print_step(tr_obj, rlang::enquos(n_recent), x$trained, title, width) + if (x$seasonal) { + title <- "# of seasonal observations per key limited to:" + n_recent <- x$n_recent + seasonal_forward_window <- x$seasonal_forward_window + seasonal_backward_window <- x$seasonal_backward_window + tr_obj <- recipes::format_selectors(rlang::enquos(n_recent, seasonal_forward_window, seasonal_backward_window), width) + recipes::print_step(tr_obj, rlang::enquos(n_recent, seasonal_forward_window, seasonal_backward_window), x$trained, title, width) + } else { + title <- "# of recent observations per key limited to:" + n_recent <- x$n_recent + tr_obj <- recipes::format_selectors(rlang::enquos(n_recent), width) + recipes::print_step(tr_obj, rlang::enquos(n_recent), x$trained, title, width) + } invisible(x) }