2626# ' @seealso [arx_class_epi_workflow()], [arx_class_args_list()]
2727# '
2828# ' @examples
29+ # ' library(dplyr)
2930# ' jhu <- case_death_rate_subset %>%
30- # ' dplyr:: filter(time_value >= as.Date("2021-11-01"))
31+ # ' filter(time_value >= as.Date("2021-11-01"))
3132# '
3233# ' out <- arx_classifier(jhu, "death_rate", c("case_rate", "death_rate"))
3334# '
@@ -45,23 +46,23 @@ arx_classifier <- function(
4546 epi_data ,
4647 outcome ,
4748 predictors ,
48- trainer = parsnip :: logistic_reg(),
49+ trainer = logistic_reg(),
4950 args_list = arx_class_args_list()) {
5051 if (! is_classification(trainer )) {
51- cli :: cli_abort(" `trainer` must be a {.pkg parsnip} model of mode 'classification'." )
52+ cli_abort(" `trainer` must be a {.pkg parsnip} model of mode 'classification'." )
5253 }
5354
5455 wf <- arx_class_epi_workflow(epi_data , outcome , predictors , trainer , args_list )
55- wf <- generics :: fit(wf , epi_data )
56+ wf <- fit(wf , epi_data )
5657
5758 preds <- forecast(
5859 wf ,
5960 fill_locf = TRUE ,
6061 n_recent = args_list $ nafill_buffer ,
6162 forecast_date = args_list $ forecast_date %|| % max(epi_data $ time_value )
6263 ) %> %
63- tibble :: as_tibble() %> %
64- dplyr :: select(- time_value )
64+ as_tibble() %> %
65+ select(- time_value )
6566
6667 structure(
6768 list (
@@ -95,17 +96,17 @@ arx_classifier <- function(
9596# ' @export
9697# ' @seealso [arx_classifier()]
9798# ' @examples
98- # '
99+ # ' library(dplyr)
99100# ' jhu <- case_death_rate_subset %>%
100- # ' dplyr:: filter(time_value >= as.Date("2021-11-01"))
101+ # ' filter(time_value >= as.Date("2021-11-01"))
101102# '
102103# ' arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate"))
103104# '
104105# ' arx_class_epi_workflow(
105106# ' jhu,
106107# ' "death_rate",
107108# ' c("case_rate", "death_rate"),
108- # ' trainer = parsnip:: multinom_reg(),
109+ # ' trainer = multinom_reg(),
109110# ' args_list = arx_class_args_list(
110111# ' breaks = c(-.05, .1), ahead = 14,
111112# ' horizon = 14, method = "linear_reg"
@@ -119,18 +120,18 @@ arx_class_epi_workflow <- function(
119120 args_list = arx_class_args_list()) {
120121 validate_forecaster_inputs(epi_data , outcome , predictors )
121122 if (! inherits(args_list , c(" arx_class" , " alist" ))) {
122- rlang :: abort( " args_list was not created using `arx_class_args_list()." )
123+ cli_abort( " ` args_list` was not created using `arx_class_args_list()` ." )
123124 }
124125 if (! (is.null(trainer ) || is_classification(trainer ))) {
125- rlang :: abort (" `trainer` must be a `{ parsnip}` model of mode 'classification'." )
126+ cli_abort (" `trainer` must be a {.pkg parsnip} model of mode 'classification'." )
126127 }
127128 lags <- arx_lags_validator(predictors , args_list $ lags )
128129
129130 # --- preprocessor
130131 # ------- predictors
131132 r <- epi_recipe(epi_data ) %> %
132133 step_growth_rate(
133- tidyselect :: all_of(predictors ),
134+ dplyr :: all_of(predictors ),
134135 role = " grp" ,
135136 horizon = args_list $ horizon ,
136137 method = args_list $ method ,
@@ -173,26 +174,24 @@ arx_class_epi_workflow <- function(
173174 o2 <- rlang :: sym(paste0(" ahead_" , args_list $ ahead , " _" , o ))
174175 r <- r %> %
175176 step_epi_ahead(!! o , ahead = args_list $ ahead , role = " pre-outcome" ) %> %
176- step_mutate(
177+ recipes :: step_mutate(
177178 outcome_class = cut(!! o2 , breaks = args_list $ breaks ),
178179 role = " outcome"
179180 ) %> %
180181 step_epi_naomit() %> %
181- step_training_window(n_recent = args_list $ n_training ) %> %
182- {
183- if (! is.null(args_list $ check_enough_data_n )) {
184- check_enough_train_data(
185- . ,
186- all_predictors(),
187- !! outcome ,
188- n = args_list $ check_enough_data_n ,
189- epi_keys = args_list $ check_enough_data_epi_keys ,
190- drop_na = FALSE
191- )
192- } else {
193- .
194- }
195- }
182+ step_training_window(n_recent = args_list $ n_training )
183+
184+ if (! is.null(args_list $ check_enough_data_n )) {
185+ r <- check_enough_train_data(
186+ r ,
187+ recipes :: all_predictors(),
188+ recipes :: all_outcomes(),
189+ n = args_list $ check_enough_data_n ,
190+ epi_keys = args_list $ check_enough_data_epi_keys ,
191+ drop_na = FALSE
192+ )
193+ }
194+
196195
197196 forecast_date <- args_list $ forecast_date %|| % max(epi_data $ time_value )
198197 target_date <- args_list $ target_date %|| % (forecast_date + args_list $ ahead )
@@ -264,7 +263,7 @@ arx_class_args_list <- function(
264263 outcome_transform = c(" growth_rate" , " lag_difference" ),
265264 breaks = 0.25 ,
266265 horizon = 7L ,
267- method = c(" rel_change" , " linear_reg" , " smooth_spline " , " trend_filter " ),
266+ method = c(" rel_change" , " linear_reg" ),
268267 log_scale = FALSE ,
269268 additional_gr_args = list (),
270269 nafill_buffer = Inf ,
@@ -274,8 +273,8 @@ arx_class_args_list <- function(
274273 rlang :: check_dots_empty()
275274 .lags <- lags
276275 if (is.list(lags )) lags <- unlist(lags )
277- method <- match.arg (method )
278- outcome_transform <- match.arg (outcome_transform )
276+ method <- rlang :: arg_match (method )
277+ outcome_transform <- rlang :: arg_match (outcome_transform )
279278
280279 arg_is_scalar(ahead , n_training , horizon , log_scale )
281280 arg_is_scalar(forecast_date , target_date , allow_null = TRUE )
@@ -287,12 +286,11 @@ arx_class_args_list <- function(
287286 if (is.finite(n_training )) arg_is_pos_int(n_training )
288287 if (is.finite(nafill_buffer )) arg_is_pos_int(nafill_buffer , allow_null = TRUE )
289288 if (! is.list(additional_gr_args )) {
290- cli :: cli_abort(
291- c(" `additional_gr_args` must be a {.cls list}." ,
292- " !" = " This is a {.cls {class(additional_gr_args)}}." ,
293- i = " See `?epiprocess::growth_rate` for available arguments."
294- )
295- )
289+ cli_abort(c(
290+ " `additional_gr_args` must be a {.cls list}." ,
291+ " !" = " This is a {.cls {class(additional_gr_args)}}." ,
292+ i = " See `?epiprocess::growth_rate` for available arguments."
293+ ))
296294 }
297295 arg_is_pos(check_enough_data_n , allow_null = TRUE )
298296 arg_is_chr(check_enough_data_epi_keys , allow_null = TRUE )
0 commit comments