Skip to content

Commit a1a53c5

Browse files
authored
Merge pull request #302 from cmu-delphi/ndefries/useful-slide-arg-errors
Check that the `f` passed to `epi[x]_slide` takes enough args
2 parents 6b8fad8 + a872728 commit a1a53c5

File tree

8 files changed

+197
-4
lines changed

8 files changed

+197
-4
lines changed

DESCRIPTION

+3-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ Description: This package introduces a common data structure for epidemiological
2121
work with revisions to these data sets over time, and offers associated
2222
utilities to perform basic signal processing tasks.
2323
License: MIT + file LICENSE
24-
Imports:
24+
Imports:
25+
cli,
2526
data.table,
2627
dplyr (>= 1.0.0),
2728
fabletools,
@@ -48,7 +49,7 @@ Suggests:
4849
knitr,
4950
outbreaks,
5051
rmarkdown,
51-
testthat (>= 3.0.0),
52+
testthat (>= 3.1.5),
5253
waldo (>= 0.3.1),
5354
withr
5455
VignetteBuilder:

NAMESPACE

+3
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,15 @@ importFrom(dplyr,ungroup)
8484
importFrom(lubridate,days)
8585
importFrom(lubridate,weeks)
8686
importFrom(magrittr,"%>%")
87+
importFrom(purrr,map_lgl)
8788
importFrom(rlang,"!!!")
8889
importFrom(rlang,"!!")
8990
importFrom(rlang,.data)
9091
importFrom(rlang,.env)
9192
importFrom(rlang,arg_match)
9293
importFrom(rlang,enquo)
9394
importFrom(rlang,enquos)
95+
importFrom(rlang,is_missing)
9496
importFrom(rlang,is_quosure)
9597
importFrom(rlang,quo_is_missing)
9698
importFrom(rlang,sym)
@@ -101,3 +103,4 @@ importFrom(tidyr,unnest)
101103
importFrom(tidyselect,eval_select)
102104
importFrom(tidyselect,starts_with)
103105
importFrom(tsibble,as_tsibble)
106+
importFrom(utils,tail)

R/grouped_epi_archive.R

+5
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,11 @@ grouped_epi_archive =
220220
ref_time_values = sort(ref_time_values)
221221
}
222222

223+
# Check that `f` takes enough args
224+
if (!missing(f) && is.function(f)) {
225+
assert_sufficient_f_args(f, ...)
226+
}
227+
223228
# Validate and pre-process `before`:
224229
if (missing(before)) {
225230
Abort("`before` is required (and must be passed by name);

R/slide.R

+6-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,12 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
155155

156156
# Check we have an `epi_df` object
157157
if (!inherits(x, "epi_df")) Abort("`x` must be of class `epi_df`.")
158-
158+
159+
# Check that `f` takes enough args
160+
if (!missing(f) && is.function(f)) {
161+
assert_sufficient_f_args(f, ...)
162+
}
163+
159164
# Arrange by increasing time_value
160165
x = arrange(x, time_value)
161166

R/utils.R

+81
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,87 @@ paste_lines = function(lines) {
100100
Abort = function(msg, ...) rlang::abort(break_str(msg, init = "Error: "), ...)
101101
Warn = function(msg, ...) rlang::warn(break_str(msg, init = "Warning: "), ...)
102102

103+
#' Assert that a sliding computation function takes enough args
104+
#'
105+
#' @param f Function; specifies a computation to slide over an `epi_df` or
106+
#' `epi_archive` in `epi_slide` or `epix_slide`.
107+
#' @param ... Dots that will be forwarded to `f` from the dots of `epi_slide` or
108+
#' `epix_slide`.
109+
#'
110+
#' @importFrom rlang is_missing
111+
#' @importFrom purrr map_lgl
112+
#' @importFrom utils tail
113+
#'
114+
#' @noRd
115+
assert_sufficient_f_args <- function(f, ...) {
116+
mandatory_f_args_labels <- c("window data", "group key")
117+
n_mandatory_f_args <- length(mandatory_f_args_labels)
118+
args = formals(args(f))
119+
args_names = names(args)
120+
# Remove named arguments forwarded from `epi[x]_slide`'s `...`:
121+
forwarded_dots_names = names(rlang::call_match(dots_expand = FALSE)[["..."]])
122+
args_matched_in_dots =
123+
# positional calling args will skip over args matched by named calling args
124+
args_names %in% forwarded_dots_names &
125+
# extreme edge case: `epi[x]_slide(<stuff>, dot = 1, `...` = 2)`
126+
args_names != "..."
127+
remaining_args = args[!args_matched_in_dots]
128+
remaining_args_names = names(remaining_args)
129+
# note that this doesn't include unnamed args forwarded through `...`.
130+
dots_i <- which(remaining_args_names == "...") # integer(0) if no match
131+
n_f_args_before_dots <- dots_i - 1L
132+
if (length(dots_i) != 0L) { # `f` has a dots "arg"
133+
# Keep all arg names before `...`
134+
mandatory_args_mapped_names <- remaining_args_names[seq_len(n_f_args_before_dots)]
135+
136+
if (n_f_args_before_dots < n_mandatory_f_args) {
137+
mandatory_f_args_in_f_dots =
138+
tail(mandatory_f_args_labels, n_mandatory_f_args - n_f_args_before_dots)
139+
cli::cli_warn(
140+
"`f` might not have enough positional arguments before its `...`; in the current `epi[x]_slide` call, the {mandatory_f_args_in_f_dots} will be included in `f`'s `...`; if `f` doesn't expect those arguments, it may produce confusing error messages",
141+
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots",
142+
epiprocess__f = f,
143+
epiprocess__mandatory_f_args_in_f_dots = mandatory_f_args_in_f_dots
144+
)
145+
}
146+
} else { # `f` doesn't have a dots "arg"
147+
if (length(args_names) < n_mandatory_f_args + rlang::dots_n(...)) {
148+
# `f` doesn't take enough args.
149+
if (rlang::dots_n(...) == 0L) {
150+
# common case; try for friendlier error message
151+
Abort(sprintf("`f` must take at least %s arguments", n_mandatory_f_args),
152+
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args",
153+
epiprocess__f = f)
154+
} else {
155+
# less common; highlight that they are (accidentally?) using dots forwarding
156+
Abort(sprintf("`f` must take at least %s arguments plus the %s arguments forwarded through `epi[x]_slide`'s `...`, or a named argument to `epi[x]_slide` was misspelled", n_mandatory_f_args, rlang::dots_n(...)),
157+
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args_plus_forwarded",
158+
epiprocess__f = f)
159+
}
160+
}
161+
}
162+
# Check for args with defaults that are filled with mandatory positional
163+
# calling args. If `f` has fewer than n_mandatory_f_args before `...`, then we
164+
# only need to check those args for defaults. Note that `n_f_args_before_dots` is
165+
# length 0 if `f` doesn't accept `...`.
166+
n_remaining_args_for_default_check = min(c(n_f_args_before_dots, n_mandatory_f_args))
167+
default_check_args = remaining_args[seq_len(n_remaining_args_for_default_check)]
168+
default_check_args_names = names(default_check_args)
169+
has_default_replaced_by_mandatory = map_lgl(default_check_args, ~!is_missing(.x))
170+
if (any(has_default_replaced_by_mandatory)) {
171+
default_check_mandatory_args_labels =
172+
mandatory_f_args_labels[seq_len(n_remaining_args_for_default_check)]
173+
# ^ excludes any mandatory args absorbed by f's `...`'s:
174+
mandatory_args_replacing_defaults =
175+
default_check_mandatory_args_labels[has_default_replaced_by_mandatory]
176+
args_with_default_replaced_by_mandatory =
177+
rlang::syms(default_check_args_names[has_default_replaced_by_mandatory])
178+
cli::cli_abort("`epi[x]_slide` would pass the {mandatory_args_replacing_defaults} to `f`'s {args_with_default_replaced_by_mandatory} argument{?s}, which {?has a/have} default value{?s}; we suspect that `f` doesn't expect {?this arg/these args} at all and may produce confusing error messages. Please add additional arguments to `f` or remove defaults as appropriate.",
179+
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults",
180+
epiprocess__f = f)
181+
}
182+
}
183+
103184
##########
104185

105186
in_range = function(x, rng) pmin(pmax(x, rng[1]), rng[2])

tests/testthat/test-epi_slide.R

+11
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,14 @@ test_that("these doesn't produce an error; the error appears only if the ref tim
8686
dplyr::select("geo_value","slide_value_value"),
8787
dplyr::tibble(geo_value = c("ak", "al"), slide_value_value = c(2, -2))) # not out of range for either group
8888
})
89+
90+
test_that("epi_slide alerts if the provided f doesn't take enough args", {
91+
f_xg = function(x, g) dplyr::tibble(value=mean(x$value), count=length(x$value))
92+
# If `regexp` is NA, asserts that there should be no errors/messages.
93+
expect_error(epi_slide(grouped, f_xg, before = 1L, ref_time_values = d+1), regexp = NA)
94+
expect_warning(epi_slide(grouped, f_xg, before = 1L, ref_time_values = d+1), regexp = NA)
95+
96+
f_x_dots = function(x, ...) dplyr::tibble(value=mean(x$value), count=length(x$value))
97+
expect_warning(epi_slide(grouped, f_x_dots, before = 1L, ref_time_values = d+1),
98+
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots")
99+
})

tests/testthat/test-epix_slide.R

+11
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,14 @@ test_that("epix_slide with all_versions option works as intended",{
348348

349349
expect_identical(xx1,xx3) # This and * Imply xx2 and xx3 are identical
350350
})
351+
352+
test_that("epix_slide alerts if the provided f doesn't take enough args", {
353+
f_xg = function(x, g) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
354+
# If `regexp` is NA, asserts that there should be no errors/messages.
355+
expect_error(epix_slide(xx, f = f_xg, before = 2L), regexp = NA)
356+
expect_warning(epix_slide(xx, f = f_xg, before = 2L), regexp = NA)
357+
358+
f_x_dots = function(x, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
359+
expect_warning(epix_slide(xx, f_x_dots, before = 2L),
360+
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots")
361+
})

tests/testthat/test-utils.R

+77-1
Original file line numberDiff line numberDiff line change
@@ -107,4 +107,80 @@ test_that("enlist works",{
107107
my_list <- enlist(x=1,y=2,z=3)
108108
expect_equal(my_list$x,1)
109109
expect_true(inherits(my_list,"list"))
110-
})
110+
})
111+
112+
test_that("assert_sufficient_f_args alerts if the provided f doesn't take enough args", {
113+
f_xg = function(x, g) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
114+
f_xg_dots = function(x, g, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
115+
116+
# If `regexp` is NA, asserts that there should be no errors/messages.
117+
expect_error(assert_sufficient_f_args(f_xg), regexp = NA)
118+
expect_warning(assert_sufficient_f_args(f_xg), regexp = NA)
119+
expect_error(assert_sufficient_f_args(f_xg_dots), regexp = NA)
120+
expect_warning(assert_sufficient_f_args(f_xg_dots), regexp = NA)
121+
122+
f_x_dots = function(x, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
123+
f_dots = function(...) dplyr::tibble(value=c(5), count=c(2))
124+
f_x = function(x) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
125+
f = function() dplyr::tibble(value=c(5), count=c(2))
126+
127+
expect_warning(assert_sufficient_f_args(f_x_dots),
128+
regexp = ", the group key will be included",
129+
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots")
130+
expect_warning(assert_sufficient_f_args(f_dots),
131+
regexp = ", the window data and group key will be included",
132+
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots")
133+
expect_error(assert_sufficient_f_args(f_x),
134+
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args")
135+
expect_error(assert_sufficient_f_args(f),
136+
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args")
137+
138+
f_xs_dots = function(x, setting="a", ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
139+
f_xs = function(x, setting="a") dplyr::tibble(value=mean(x$binary), count=length(x$binary))
140+
expect_warning(assert_sufficient_f_args(f_xs_dots, setting="b"),
141+
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots")
142+
expect_error(assert_sufficient_f_args(f_xs, setting="b"),
143+
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args_plus_forwarded")
144+
145+
expect_error(assert_sufficient_f_args(f_xg, "b"),
146+
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args_plus_forwarded")
147+
})
148+
149+
test_that("assert_sufficient_f_args alerts if the provided f has defaults for the required args", {
150+
f_xg = function(x, g=1) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
151+
f_xg_dots = function(x=1, g, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
152+
f_x_dots = function(x=1, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
153+
154+
expect_error(assert_sufficient_f_args(f_xg),
155+
regexp = "pass the group key to `f`'s g argument,",
156+
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
157+
expect_error(assert_sufficient_f_args(f_xg_dots),
158+
regexp = "pass the window data to `f`'s x argument,",
159+
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
160+
expect_error(suppressWarnings(assert_sufficient_f_args(f_x_dots)),
161+
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
162+
163+
f_xsg = function(x, setting="a", g) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
164+
f_xsg_dots = function(x, setting="a", g, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
165+
f_xs_dots = function(x=1, setting="a", ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
166+
167+
# forwarding named dots should prevent some complaints:
168+
expect_no_error(assert_sufficient_f_args(f_xsg, setting = "b"))
169+
expect_no_error(assert_sufficient_f_args(f_xsg_dots, setting = "b"))
170+
expect_error(suppressWarnings(assert_sufficient_f_args(f_xs_dots, setting = "b")),
171+
regexp = "window data to `f`'s x argument",
172+
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
173+
174+
# forwarding unnamed dots should not:
175+
expect_error(assert_sufficient_f_args(f_xsg, "b"),
176+
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
177+
expect_error(assert_sufficient_f_args(f_xsg_dots, "b"),
178+
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
179+
expect_error(assert_sufficient_f_args(f_xs_dots, "b"),
180+
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
181+
182+
# forwarding no dots should produce a different error message in some cases:
183+
expect_error(assert_sufficient_f_args(f_xs_dots),
184+
regexp = "window data and group key to `f`'s x and setting argument",
185+
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
186+
})

0 commit comments

Comments
 (0)