Skip to content

fix: update for compatibility with epiprocess==0.9.0 #386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
^DEVELOPMENT\.md$
^doc$
^Meta$
^.lintr$
^.lintr$
^.venv$
5 changes: 2 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.20
Version: 0.0.21
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand All @@ -23,8 +23,7 @@ URL: https://github.com/cmu-delphi/epipredict/,
https://cmu-delphi.github.io/epipredict
BugReports: https://github.com/cmu-delphi/epipredict/issues/
Depends:
epiprocess (>= 0.8.0),
epiprocess (< 0.9.0),
epiprocess (>= 0.9.0),
parsnip (>= 1.0.0),
R (>= 3.5.0)
Imports:
Expand Down
2 changes: 1 addition & 1 deletion R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ autoplot.epi_workflow <- function(
if (length(extra_keys) == 0L) extra_keys <- NULL
edf <- as_epi_df(edf,
as_of = object$fit$meta$as_of,
additional_metadata = list(other_keys = extra_keys)
other_keys = extra_keys %||% character()
)
if (is.null(predictions)) {
return(autoplot(
Expand Down
6 changes: 3 additions & 3 deletions R/cdc_baseline_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
#' select(-pop, -death_rate) %>%
#' group_by(geo_value) %>%
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
#' epi_slide(~ sum(.$deaths), .window_size = 7, .new_col_name = "deaths_7dsum") %>%
#' ungroup() %>%
#' filter(weekdays(time_value) == "Saturday")
#'
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths_7dsum")
#' preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn)
#'
#' if (require(ggplot2)) {
Expand All @@ -47,7 +47,7 @@
#' geom_line(aes(y = .pred), color = "orange") +
#' geom_line(
#' data = weekly_deaths %>% filter(geo_value %in% four_states),
#' aes(x = time_value, y = deaths)
#' aes(x = time_value, y = deaths_7dsum)
#' ) +
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
#' labs(x = "Date", y = "Weekly deaths") +
Expand Down
12 changes: 7 additions & 5 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ epi_recipe.epi_df <-
keys <- key_colnames(x) # we know x is an epi_df

var_info <- tibble(variable = vars)
key_roles <- c("geo_value", "time_value", rep("key", length(keys) - 2))
key_roles <- c("geo_value", rep("key", length(keys) - 2), "time_value")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to keep these in the same order as https://github.com/cmu-delphi/epiprocess/blob/16f38b2b386522f7f8f4880157432a3fa4a8d6ae/R/epi_df.R#L168.

Does this break something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is updated in our upcoming epiprocess PR. If this line isn't changed then the wrong columns are associated with the wrong roles.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Is there a reason for that choice? I think the previous ordering was based on a Slack vote that @brookslogan took a while back? I can't seem to find it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the Slack vote asked for people's preference for the order to add arguments to a function, not how they want to order their epi_df.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Right, and this was about API data-fetching functions, rather than epiprocess ones. In arrange_canonical I think we just borrowed this vote without doing another one... and I do wonder about my polling methodology anyway.)


## Check and add roles when available
if (!is.null(roles)) {
Expand Down Expand Up @@ -499,8 +499,11 @@ prep.epi_recipe <- function(
if (!is_epi_df(training)) {
# tidymodels killed our class
# for now, we only allow step_epi_* to alter the metadata
training <- dplyr::dplyr_reconstruct(
as_epi_df(training), before_template
metadata <- attr(before_template, "metadata")
training <- as_epi_df(
training,
as_of = metadata$as_of,
other_keys = metadata$other_keys %||% character()
)
}
training <- dplyr::relocate(training, all_of(key_colnames(training)))
Expand Down Expand Up @@ -579,8 +582,7 @@ bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") {
new_data <- as_epi_df(
new_data,
as_of = meta$as_of,
# avoid NULL if meta is from saved older epi_df:
additional_metadata = meta$additional_metadata %||% list()
other_keys = meta$other_keys %||% character()
)
}
new_data
Expand Down
3 changes: 2 additions & 1 deletion R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ is_epi_workflow <- function(x) {
fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()) {
object$fit$meta <- list(
max_time_value = max(data$time_value),
as_of = attributes(data)$metadata$as_of
as_of = attr(data, "metadata")$as_of,
other_keys = attr(data, "metadata")$other_keys
)
object$original_data <- data

Expand Down
4 changes: 2 additions & 2 deletions R/flusight_hub_formatter.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ abbr_to_location <- function(abbr) {
#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
#' select(-pop, -death_rate) %>%
#' group_by(geo_value) %>%
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
#' epi_slide(~ sum(.$deaths), .window_size = 7, .new_col_name = "deaths_7dsum") %>%
#' ungroup() %>%
#' filter(weekdays(time_value) == "Saturday")
#'
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths_7dsum")
#' flusight_hub_formatter(cdc)
#' flusight_hub_formatter(cdc, target = "wk inc covid deaths")
#' flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths"))
Expand Down
15 changes: 8 additions & 7 deletions R/key_colnames.R
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
#' @export
key_colnames.recipe <- function(x, ...) {
possible_keys <- c("geo_value", "time_value", "key")
keys <- x$var_info$variable[x$var_info$role %in% possible_keys]
keys[order(match(keys, possible_keys))] %||% character(0L)
geo_key <- x$var_info$variable[x$var_info$role %in% "geo_value"]
time_key <- x$var_info$variable[x$var_info$role %in% "time_value"]
keys <- x$var_info$variable[x$var_info$role %in% "key"]
c(geo_key, keys, time_key) %||% character(0L)
}

#' @export
key_colnames.epi_workflow <- function(x, ...) {
# safer to look at the mold than the preprocessor
mold <- hardhat::extract_mold(x)
possible_keys <- c("geo_value", "time_value", "key")
molded_names <- names(mold$extras$roles)
keys <- map(mold$extras$roles[molded_names %in% possible_keys], names)
keys <- unname(unlist(keys))
keys[order(match(keys, possible_keys))] %||% character(0L)
geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value)
time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value)
keys <- names(mold$extras$roles[molded_names %in% "key"]$key)
c(geo_key, keys, time_key) %||% character(0L)
}

kill_time_value <- function(v) {
Expand Down
1 change: 1 addition & 0 deletions R/layer_add_forecast_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ slather.layer_add_forecast_date <- function(object, components, workflow,
workflows::extract_preprocessor(workflow)$template, "metadata"
)$time_type
if (expected_time_type == "week") expected_time_type <- "day"
if (expected_time_type == "integer") expected_time_type <- "year"
validate_date(
forecast_date, expected_time_type,
call = rlang::expr(layer_add_forecast_date())
Expand Down
1 change: 1 addition & 0 deletions R/layer_add_target_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ slather.layer_add_target_date <- function(object, components, workflow,
workflows::extract_preprocessor(workflow)$template, "metadata"
)$time_type
if (expected_time_type == "week") expected_time_type <- "day"
if (expected_time_type == "integer") expected_time_type <- "year"

if (!is.null(object$target_date)) {
target_date <- object$target_date
Expand Down
121 changes: 64 additions & 57 deletions R/step_epi_slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@
#' argument must be named `.x`. A common, though very difficult to debug
#' error is using something like `function(x) mean`. This will not work
#' because it returns the function mean, rather than `mean(x)`
#' @param before,after the size of the sliding window on the left and the right
#' of the center. Usually non-negative integers for data indexed by date, but
#' more restrictive in other cases (see [epiprocess::epi_slide()] for details).
#' @param f_name a character string of at most 20 characters that describes
#' the function. This will be combined with `prefix` and the columns in `...`
#' to name the result using `{prefix}{f_name}_{column}`. By default it will be determined
#' automatically using `clean_f_name()`.
#' @param .window_size the size of the sliding window, required. Usually a
#' non-negative integer will suffice (e.g. for data indexed by date, but more
#' restrictive in other time_type cases (see [epiprocess::epi_slide()] for
#' details). For example, set to 7 for a 7-day window.
#' @param .align a character string indicating how the window should be aligned.
#' By default, this is "right", meaning the slide_window will be anchored with
#' its right end point on the reference date. (see [epiprocess::epi_slide()]
#' for details).
#' @param f_name a character string of at most 20 characters that describes the
#' function. This will be combined with `prefix` and the columns in `...` to
#' name the result using `{prefix}{f_name}_{column}`. By default it will be
#' determined automatically using `clean_f_name()`.
#'
#' @template step-return
#'
Expand All @@ -37,53 +42,55 @@
#' rec <- epi_recipe(jhu) %>%
#' step_epi_slide(case_rate, death_rate,
#' .f = \(x) mean(x, na.rm = TRUE),
#' before = 6L
#' .window_size = 7L
#' )
#' bake(prep(rec, jhu), new_data = NULL)
step_epi_slide <-
function(recipe,
...,
.f,
before = 0L,
after = 0L,
role = "predictor",
prefix = "epi_slide_",
f_name = clean_f_name(.f),
skip = FALSE,
id = rand_id("epi_slide")) {
if (!is_epi_recipe(recipe)) {
cli_abort("This recipe step can only operate on an {.cls epi_recipe}.")
}
.f <- validate_slide_fun(.f)
epiprocess:::validate_slide_window_arg(before, attributes(recipe$template)$metadata$time_type)
epiprocess:::validate_slide_window_arg(after, attributes(recipe$template)$metadata$time_type)
arg_is_chr_scalar(role, prefix, id)
arg_is_lgl_scalar(skip)
step_epi_slide <- function(recipe,
...,
.f,
.window_size = NULL,
.align = c("right", "center", "left"),
role = "predictor",
prefix = "epi_slide_",
f_name = clean_f_name(.f),
skip = FALSE,
id = rand_id("epi_slide")) {
if (!is_epi_recipe(recipe)) {
cli_abort("This recipe step can only operate on an {.cls epi_recipe}.")
}
.f <- validate_slide_fun(.f)
if (is.null(.window_size)) {
cli_abort("step_epi_slide: `.window_size` must be specified.")
}
epiprocess:::validate_slide_window_arg(.window_size, attributes(recipe$template)$metadata$time_type)
.align <- rlang::arg_match(.align)
arg_is_chr_scalar(role, prefix, id)
arg_is_lgl_scalar(skip)

recipes::add_step(
recipe,
step_epi_slide_new(
terms = enquos(...),
before = before,
after = after,
.f = .f,
f_name = f_name,
role = role,
trained = FALSE,
prefix = prefix,
keys = key_colnames(recipe),
columns = NULL,
skip = skip,
id = id
)
recipes::add_step(
recipe,
step_epi_slide_new(
terms = enquos(...),
.window_size = .window_size,
.align = .align,
.f = .f,
f_name = f_name,
role = role,
trained = FALSE,
prefix = prefix,
keys = key_colnames(recipe),
columns = NULL,
skip = skip,
id = id
)
}
)
}


step_epi_slide_new <-
function(terms,
before,
after,
.window_size,
.align,
.f,
f_name,
role,
Expand All @@ -96,8 +103,8 @@ step_epi_slide_new <-
recipes::step(
subclass = "epi_slide",
terms = terms,
before = before,
after = after,
.window_size = .window_size,
.align = .align,
.f = .f,
f_name = f_name,
role = role,
Expand All @@ -119,8 +126,8 @@ prep.step_epi_slide <- function(x, training, info = NULL, ...) {

step_epi_slide_new(
terms = x$terms,
before = x$before,
after = x$after,
.window_size = x$.window_size,
.align = x$.align,
.f = x$.f,
f_name = x$f_name,
role = x$role,
Expand Down Expand Up @@ -165,8 +172,8 @@ bake.step_epi_slide <- function(object, new_data, ...) {
# }
epi_slide_wrapper(
new_data,
object$before,
object$after,
object$.window_size,
object$.align,
object$columns,
c(object$.f),
object$f_name,
Expand All @@ -190,7 +197,7 @@ bake.step_epi_slide <- function(object, new_data, ...) {
#' @importFrom dplyr bind_cols group_by ungroup
#' @importFrom epiprocess epi_slide
#' @keywords internal
epi_slide_wrapper <- function(new_data, before, after, columns, fns, fn_names, group_keys, name_prefix) {
epi_slide_wrapper <- function(new_data, .window_size, .align, columns, fns, fn_names, group_keys, name_prefix) {
cols_fns <- tidyr::crossing(col_name = columns, fn_name = fn_names, fn = fns)
# Iterate over the rows of cols_fns. For each row number, we will output a
# transformed column. The first result returns all the original columns along
Expand All @@ -204,10 +211,10 @@ epi_slide_wrapper <- function(new_data, before, after, columns, fns, fn_names, g
result <- new_data %>%
group_by(across(all_of(group_keys))) %>%
epi_slide(
before = before,
after = after,
new_col_name = result_name,
f = function(slice, geo_key, ref_time_value) {
.window_size = .window_size,
.align = .align,
.new_col_name = result_name,
.f = function(slice, geo_key, ref_time_value) {
fn(slice[[col_name]])
}
) %>%
Expand Down
16 changes: 7 additions & 9 deletions R/utils-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,26 @@ check_pname <- function(res, preds, object, newname = NULL) {


grab_forged_keys <- function(forged, workflow, new_data) {
keys <- c("geo_value", "time_value", "key")
forged_roles <- names(forged$extras$roles)
extras <- dplyr::bind_cols(forged$extras$roles[forged_roles %in% keys])
extras <- dplyr::bind_cols(forged$extras$roles[forged_roles %in% c("geo_value", "time_value", "key")])
# 1. these are the keys in the test data after prep/bake
new_keys <- names(extras)
# 2. these are the keys in the training data
old_keys <- key_colnames(workflow)
# 3. these are the keys in the test data as input
new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, keys[1:2]))
new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, c("geo_value", "time_value")))
if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) {
cli::cli_warn(c(
"Not all epi keys that were present in the training data are available",
"in `new_data`. Predictions will have only the available keys."
))
}
if (is_epi_df(new_data)) {
extras <- as_epi_df(extras)
attr(extras, "metadata") <- attr(new_data, "metadata")
} else if (all(keys[1:2] %in% new_keys)) {
l <- list()
if (length(new_keys) > 2) l <- list(other_keys = new_keys[-c(1:2)])
extras <- as_epi_df(extras, additional_metadata = l)
meta <- attr(new_data, "metadata")
extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys %||% character())
} else if (all(c("geo_value", "time_value") %in% new_keys)) {
if (length(new_keys) > 2) other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")]
extras <- as_epi_df(extras, other_keys = other_keys %||% character())
}
extras
}
Expand Down
2 changes: 1 addition & 1 deletion data-raw/grad_employ_subset.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,6 @@ ncol(gemploy)
grad_employ_subset <- gemploy %>%
as_epi_df(
as_of = "2022-07-19",
additional_metadata = list(other_keys = c("age_group", "edu_qual"))
other_keys = c("age_group", "edu_qual")
)
usethis::use_data(grad_employ_subset, overwrite = TRUE)
Binary file modified data/grad_employ_subset.rda
Binary file not shown.
2 changes: 2 additions & 0 deletions man/autoplot-epipred.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading