Skip to content

feat: add seasonal window filtering to step_training_window #461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 64 additions & 7 deletions R/step_training_window.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@
#' @param n_recent An integer value that represents the number of most recent
#' observations that are to be kept in the training window per key.
#' The default value is 50.
#' @param seasonal Logical, default `FALSE`. If `TRUE`, the training window
#' will slice through epidemic seasons. This is useful for forecasting models
#' that need to leverage data from previous years, but only from similar
#' phases of the epidemic. Most useful for heavily seasonal data, like
#' influenza. Expects `n_recent` to be finite.
#' @param seasonal_forward_window An integer value that represents the number
#' of days after a season week to include in the training window. The default
#' value is 14. Only used when `seasonal` is `TRUE`.
#' @param seasonal_backward_window An integer value that represents the number
#' of days before a season week to include in the training window. The default
#' value is 35. Only used when `seasonal` is `TRUE`.
#' @param epi_keys An optional character vector for specifying "key" variables
#' to group on. The default, `NULL`, ensures that every key combination is
#' limited.
Expand Down Expand Up @@ -42,10 +53,13 @@ step_training_window <-
function(recipe,
role = NA,
n_recent = 50,
seasonal = FALSE,
seasonal_forward_window = 14,
seasonal_backward_window = 35,
epi_keys = NULL,
id = rand_id("training_window")) {
arg_is_scalar(n_recent, id)
arg_is_pos(n_recent)
arg_is_scalar(n_recent, id, seasonal, seasonal_forward_window, seasonal_backward_window)
arg_is_pos(n_recent, seasonal_forward_window, seasonal_backward_window)
if (is.finite(n_recent)) arg_is_pos_int(n_recent)
arg_is_chr(id)
arg_is_chr(epi_keys, allow_null = TRUE)
Expand All @@ -55,6 +69,9 @@ step_training_window <-
role = role,
trained = FALSE,
n_recent = n_recent,
seasonal = seasonal,
seasonal_forward_window = seasonal_forward_window,
seasonal_backward_window = seasonal_backward_window,
epi_keys = epi_keys,
skip = TRUE,
id = id
Expand All @@ -63,12 +80,15 @@ step_training_window <-
}

step_training_window_new <-
function(role, trained, n_recent, epi_keys, skip, id) {
function(role, trained, n_recent, seasonal, seasonal_forward_window, seasonal_backward_window, epi_keys, skip, id) {
step(
subclass = "training_window",
role = role,
trained = trained,
n_recent = n_recent,
seasonal = seasonal,
seasonal_forward_window = seasonal_forward_window,
seasonal_backward_window = seasonal_backward_window,
epi_keys = epi_keys,
skip = skip,
id = id
Expand All @@ -86,6 +106,9 @@ prep.step_training_window <- function(x, training, info = NULL, ...) {
role = x$role,
trained = TRUE,
n_recent = x$n_recent,
seasonal = x$seasonal,
seasonal_forward_window = x$seasonal_forward_window,
seasonal_backward_window = x$seasonal_backward_window,
epi_keys = ek,
skip = x$skip,
id = x$id
Expand All @@ -104,15 +127,49 @@ bake.step_training_window <- function(object, new_data, ...) {
ungroup()
}

if (object$seasonal) {
# TODO: This needs to take into account different time types of time_value.
# Currently, it assumes time_value is a Date.
new_data <- new_data %>% add_season_info()
last_data_season_week <- new_data %>%
filter(time_value == max(time_value)) %>%
pull(season_week) %>%
max()
recent_weeks <- c(last_data_season_week)
if (inherits(new_data, "epi_df")) {
current_season_week <- convert_epiweek_to_season_week(epiyear(epi_as_of(new_data)), epiweek(epi_as_of(new_data)))
recent_weeks <- c(recent_weeks, current_season_week)
}
date_ranges <- new_data %>%
filter(season_week %in% recent_weeks) %>%
pull(time_value) %>%
unique() %>%
map(~ c(.x - 1:(object$seasonal_backward_window), .x + 0:(object$seasonal_forward_window))) %>%
unlist() %>%
as.Date() %>%
unique()
new_data %<>% filter(time_value %in% date_ranges)
}


new_data
}

#' @export
print.step_training_window <-
function(x, width = max(20, options()$width - 30), ...) {
title <- "# of recent observations per key limited to:"
n_recent <- x$n_recent
tr_obj <- recipes::format_selectors(rlang::enquos(n_recent), width)
recipes::print_step(tr_obj, rlang::enquos(n_recent), x$trained, title, width)
if (x$seasonal) {
title <- "# of seasonal observations per key limited to:"
n_recent <- x$n_recent
seasonal_forward_window <- x$seasonal_forward_window
seasonal_backward_window <- x$seasonal_backward_window
tr_obj <- recipes::format_selectors(rlang::enquos(n_recent, seasonal_forward_window, seasonal_backward_window), width)
recipes::print_step(tr_obj, rlang::enquos(n_recent, seasonal_forward_window, seasonal_backward_window), x$trained, title, width)
} else {
title <- "# of recent observations per key limited to:"
n_recent <- x$n_recent
tr_obj <- recipes::format_selectors(rlang::enquos(n_recent), width)
recipes::print_step(tr_obj, rlang::enquos(n_recent), x$trained, title, width)
}
invisible(x)
}
Loading