Skip to content

Commit 30eb7f8

Browse files
committed
autoplot new data
1 parent 9836aff commit 30eb7f8

File tree

1 file changed

+34
-28
lines changed

1 file changed

+34
-28
lines changed

R/autoplot.R

+34-28
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ ggplot2::autoplot
1616
#' @param object An `epi_workflow`
1717
#' @param predictions A data frame with predictions. If `NULL`, only the
1818
#' original data is shown.
19+
#' @param plot_data An epi_df of the data to plot against. This is for the case
20+
#' where you have the actual results to compare the forecast against.
1921
#' @param .levels A numeric vector of levels to plot for any prediction bands.
2022
#' More than 3 levels begins to be difficult to see.
2123
#' @param ... Ignored
@@ -84,7 +86,9 @@ NULL
8486
#' @export
8587
#' @rdname autoplot-epipred
8688
autoplot.epi_workflow <- function(
87-
object, predictions = NULL,
89+
object,
90+
predictions = NULL,
91+
plot_data = NULL,
8892
.levels = c(.5, .8, .95), ...,
8993
.color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"),
9094
.facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"),
@@ -111,30 +115,32 @@ autoplot.epi_workflow <- function(
111115
}
112116
keys <- c("geo_value", "time_value", "key")
113117
mold_roles <- names(mold$extras$roles)
114-
edf <- bind_cols(mold$extras$roles[mold_roles %in% keys], y)
115-
if (starts_with_impl("ahead_", names(y))) {
116-
old_name_y <- unlist(strsplit(names(y), "_"))
117-
shift <- as.numeric(old_name_y[2])
118-
new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
119-
edf <- rename(edf, !!new_name_y := !!names(y))
120-
} else if (starts_with_impl("lag_", names(y))) {
121-
old_name_y <- unlist(strsplit(names(y), "_"))
122-
shift <- -as.numeric(old_name_y[2])
123-
new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
124-
edf <- rename(edf, !!new_name_y := !!names(y))
125-
}
126-
127-
if (!is.null(shift)) {
128-
edf <- mutate(edf, time_value = time_value + shift)
118+
# extract the relevant column names for plotting
119+
old_name_y <- unlist(strsplit(names(y), "_"))
120+
new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
121+
if (is.null(plot_data)) {
122+
# the outcome has shifted, so we need to shift it forward (or back)
123+
# by the corresponding amount
124+
plot_data <- bind_cols(mold$extras$roles[mold_roles %in% keys], y)
125+
if (starts_with_impl("ahead_", names(y))) {
126+
shift <- as.numeric(old_name_y[2])
127+
} else if (starts_with_impl("lag_", names(y))) {
128+
old_name_y <- unlist(strsplit(names(y), "_"))
129+
shift <- -as.numeric(old_name_y[2])
130+
}
131+
plot_data <- rename(plot_data, !!new_name_y := !!names(y))
132+
if (!is.null(shift)) {
133+
plot_data <- mutate(plot_data, time_value = time_value + shift)
134+
}
135+
other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
136+
plot_data <- as_epi_df(plot_data,
137+
as_of = object$fit$meta$as_of,
138+
other_keys = other_keys
139+
)
129140
}
130-
other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
131-
edf <- as_epi_df(edf,
132-
as_of = object$fit$meta$as_of,
133-
other_keys = other_keys
134-
)
135141
if (is.null(predictions)) {
136142
return(autoplot(
137-
edf, new_name_y,
143+
plot_data, new_name_y,
138144
.color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color,
139145
.max_facets = .max_facets
140146
))
@@ -146,27 +152,27 @@ autoplot.epi_workflow <- function(
146152
}
147153
predictions <- rename(predictions, time_value = target_date)
148154
}
149-
pred_cols_ok <- hardhat::check_column_names(predictions, key_colnames(edf))
155+
pred_cols_ok <- hardhat::check_column_names(predictions, key_colnames(plot_data))
150156
if (!pred_cols_ok$ok) {
151157
cli_warn(c(
152158
"`predictions` is missing required variables: {.var {pred_cols_ok$missing_names}}.",
153159
i = "Plotting the original data."
154160
))
155161
return(autoplot(
156-
edf, !!new_name_y,
162+
plot_data, !!new_name_y,
157163
.color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color,
158164
.max_facets = .max_facets
159165
))
160166
}
161167

162168
# First we plot the history, always faceted by everything
163-
bp <- autoplot(edf, !!new_name_y,
169+
bp <- autoplot(plot_data, !!new_name_y,
164170
.color_by = "none", .facet_by = "all_keys",
165171
.base_color = "black", .max_facets = .max_facets
166172
)
167173

168174
# Now, prepare matching facets in the predictions
169-
ek <- epi_keys_only(edf)
175+
ek <- epi_keys_only(plot_data)
170176
predictions <- predictions %>%
171177
mutate(
172178
.facets = interaction(!!!rlang::syms(as.list(ek)), sep = "/"),
@@ -204,7 +210,7 @@ autoplot.epi_workflow <- function(
204210
#' @export
205211
#' @rdname autoplot-epipred
206212
autoplot.canned_epipred <- function(
207-
object, ...,
213+
object, plot_data = NULL, ...,
208214
.color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"),
209215
.facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"),
210216
.base_color = "dodgerblue4",
@@ -218,7 +224,7 @@ autoplot.canned_epipred <- function(
218224
predictions <- object$predictions %>%
219225
rename(time_value = target_date)
220226

221-
autoplot(ewf, predictions,
227+
autoplot(ewf, predictions, plot_data, ...,
222228
.color_by = .color_by, .facet_by = .facet_by,
223229
.base_color = .base_color, .max_facets = .max_facets
224230
)

0 commit comments

Comments
 (0)