Skip to content

Commit 1eceeb6

Browse files
authored
Merge pull request #337 from cmu-delphi/ndefries/func-conversion-expansion
Refactor slide computation function generation and move to `as_slide_computation`
2 parents a9128e9 + 7be2e66 commit 1eceeb6

File tree

7 files changed

+192
-248
lines changed

7 files changed

+192
-248
lines changed

NAMESPACE

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,16 @@ importFrom(rlang,.env)
9696
importFrom(rlang,arg_match)
9797
importFrom(rlang,caller_arg)
9898
importFrom(rlang,caller_env)
99-
importFrom(rlang,check_dots_empty0)
10099
importFrom(rlang,enquo)
101100
importFrom(rlang,enquos)
102101
importFrom(rlang,env)
103102
importFrom(rlang,f_env)
104103
importFrom(rlang,f_rhs)
105-
importFrom(rlang,global_env)
106104
importFrom(rlang,is_environment)
107105
importFrom(rlang,is_formula)
108106
importFrom(rlang,is_function)
109107
importFrom(rlang,is_missing)
110108
importFrom(rlang,is_quosure)
111-
importFrom(rlang,is_string)
112109
importFrom(rlang,missing_arg)
113110
importFrom(rlang,new_function)
114111
importFrom(rlang,quo_is_missing)

R/grouped_epi_archive.R

Lines changed: 62 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ grouped_epi_archive =
186186
#' object. See the documentation for the wrapper function [`epix_slide()`] for
187187
#' details.
188188
#' @importFrom data.table key address
189-
#' @importFrom rlang !! !!! enquo quo_is_missing enquos is_quosure sym syms env
189+
#' @importFrom rlang !! !!! enquo quo_is_missing enquos is_quosure sym syms
190+
#' env missing_arg
190191
slide = function(f, ..., before, ref_time_values,
191192
time_step, new_col_name = "slide_value",
192193
as_list_col = FALSE, names_sep = "_",
@@ -229,11 +230,6 @@ grouped_epi_archive =
229230
# implementation doesn't take advantage of it.
230231
ref_time_values = sort(ref_time_values)
231232
}
232-
233-
# Check that `f` takes enough args
234-
if (!missing(f) && is.function(f)) {
235-
assert_sufficient_f_args(f, ...)
236-
}
237233

238234
# Validate and pre-process `before`:
239235
if (missing(before)) {
@@ -296,71 +292,8 @@ grouped_epi_archive =
296292
!!new_col := .env$comp_value))
297293
}
298294

299-
# If f is not missing, then just go ahead, slide by group
300-
if (!missing(f)) {
301-
if (rlang::is_formula(f)) f = as_slide_computation(f)
302-
x = purrr::map_dfr(ref_time_values, function(ref_time_value) {
303-
# Ungrouped as-of data; `epi_df` if `all_versions` is `FALSE`,
304-
# `epi_archive` if `all_versions` is `TRUE`:
305-
as_of_raw = private$ungrouped$as_of(ref_time_value, min_time_value = ref_time_value - before, all_versions = all_versions)
306-
307-
# Set:
308-
# * `as_of_df`, the data.frame/tibble/epi_df/etc. that we will
309-
# `group_modify` as the `.data` argument. Might or might not
310-
# include version column.
311-
# * `group_modify_fn`, the corresponding `.f` argument
312-
if (!all_versions) {
313-
as_of_df = as_of_raw
314-
group_modify_fn = comp_one_grp
315-
} else {
316-
as_of_archive = as_of_raw
317-
# We essentially want to `group_modify` the archive, but
318-
# haven't implemented this method yet. Next best would be
319-
# `group_modify` on its `$DT`, but that has different
320-
# behavior based on whether or not `dtplyr` is loaded.
321-
# Instead, go through an ordinary data frame, trying to avoid
322-
# copies.
323-
if (address(as_of_archive$DT) == address(private$ungrouped$DT)) {
324-
# `as_of` aliased its the full `$DT`; copy before mutating:
325-
as_of_archive$DT <- copy(as_of_archive$DT)
326-
}
327-
dt_key = data.table::key(as_of_archive$DT)
328-
as_of_df = as_of_archive$DT
329-
data.table::setDF(as_of_df)
330-
331-
# Convert each subgroup chunk to an archive before running the calculation.
332-
group_modify_fn = function(.data_group, .group_key,
333-
f, ...,
334-
ref_time_value,
335-
new_col) {
336-
# .data_group is coming from as_of_df as a tibble, but we
337-
# want to feed `comp_one_grp` an `epi_archive` backed by a
338-
# DT; convert and wrap:
339-
data.table::setattr(.data_group, "sorted", dt_key)
340-
data.table::setDT(.data_group, key=dt_key)
341-
.data_group_archive = as_of_archive$clone()
342-
.data_group_archive$DT = .data_group
343-
comp_one_grp(.data_group_archive, .group_key, f = f, ...,
344-
ref_time_value = ref_time_value,
345-
new_col = new_col
346-
)
347-
}
348-
}
349-
350-
return(
351-
dplyr::group_by(as_of_df, dplyr::across(tidyselect::all_of(private$vars)),
352-
.drop=private$drop) %>%
353-
dplyr::group_modify(group_modify_fn,
354-
f = f, ...,
355-
ref_time_value = ref_time_value,
356-
new_col = new_col,
357-
.keep = TRUE)
358-
)
359-
})
360-
}
361-
362-
# Else interpret ... as an expression for tidy evaluation
363-
else {
295+
# If `f` is missing, interpret ... as an expression for tidy evaluation
296+
if (missing(f)) {
364297
quos = enquos(...)
365298
if (length(quos) == 0) {
366299
Abort("If `f` is missing then a computation must be specified via `...`.")
@@ -369,83 +302,70 @@ grouped_epi_archive =
369302
Abort("If `f` is missing then only a single computation can be specified via `...`.")
370303
}
371304

372-
quo = quos[[1]]
373-
f = function(.x, .group_key, .ref_time_value, quo, ...) {
374-
# Convert to environment to standardize between tibble and R6
375-
# based inputs. In both cases, we should get a simple
376-
# environment with the empty environment as its parent.
377-
data_env = rlang::as_environment(.x)
378-
data_mask = rlang::new_data_mask(bottom = data_env, top = data_env)
379-
data_mask$.data <- rlang::as_data_pronoun(data_mask)
380-
# We'll also install `.x` directly, not as an
381-
# `rlang_data_pronoun`, so that we can, e.g., use more dplyr and
382-
# epiprocess operations.
383-
data_mask$.x = .x
384-
data_mask$.group_key = .group_key
385-
data_mask$.ref_time_value = .ref_time_value
386-
rlang::eval_tidy(quo, data_mask)
387-
}
305+
f = quos[[1]]
388306
new_col = sym(names(rlang::quos_auto_name(quos)))
307+
... = missing_arg() # magic value that passes zero args as dots in calls below
308+
}
389309

390-
x = purrr::map_dfr(ref_time_values, function(ref_time_value) {
391-
# Ungrouped as-of data; `epi_df` if `all_versions` is `FALSE`,
392-
# `epi_archive` if `all_versions` is `TRUE`:
393-
as_of_raw = private$ungrouped$as_of(ref_time_value, min_time_value = ref_time_value - before, all_versions = all_versions)
310+
f = as_slide_computation(f, ...)
311+
x = purrr::map_dfr(ref_time_values, function(ref_time_value) {
312+
# Ungrouped as-of data; `epi_df` if `all_versions` is `FALSE`,
313+
# `epi_archive` if `all_versions` is `TRUE`:
314+
as_of_raw = private$ungrouped$as_of(ref_time_value, min_time_value = ref_time_value - before, all_versions = all_versions)
394315

395-
# Set:
396-
# * `as_of_df`, the data.frame/tibble/epi_df/etc. that we will
397-
# `group_modify` as the `.data` argument. Might or might not
398-
# include version column.
399-
# * `group_modify_fn`, the corresponding `.f` argument
400-
if (!all_versions) {
401-
as_of_df = as_of_raw
402-
group_modify_fn = comp_one_grp
403-
} else {
404-
as_of_archive = as_of_raw
405-
# We essentially want to `group_modify` the archive, but don't
406-
# provide an implementation yet. Next best would be
407-
# `group_modify` on its `$DT`, but that has different behavior
408-
# based on whether or not `dtplyr` is loaded. Instead, go
409-
# through an ordinary data frame, trying to avoid copies.
410-
if (address(as_of_archive$DT) == address(private$ungrouped$DT)) {
411-
# `as_of` aliased its the full `$DT`; copy before mutating:
412-
as_of_archive$DT <- copy(as_of_archive$DT)
413-
}
414-
dt_key = data.table::key(as_of_archive$DT)
415-
as_of_df = as_of_archive$DT
416-
data.table::setDF(as_of_df)
316+
# Set:
317+
# * `as_of_df`, the data.frame/tibble/epi_df/etc. that we will
318+
# `group_modify` as the `.data` argument. Might or might not
319+
# include version column.
320+
# * `group_modify_fn`, the corresponding `.f` argument
321+
if (!all_versions) {
322+
as_of_df = as_of_raw
323+
group_modify_fn = comp_one_grp
324+
} else {
325+
as_of_archive = as_of_raw
326+
# We essentially want to `group_modify` the archive, but
327+
# haven't implemented this method yet. Next best would be
328+
# `group_modify` on its `$DT`, but that has different
329+
# behavior based on whether or not `dtplyr` is loaded.
330+
# Instead, go through an ordinary data frame, trying to avoid
331+
# copies.
332+
if (address(as_of_archive$DT) == address(private$ungrouped$DT)) {
333+
# `as_of` aliased the full `$DT`; copy before mutating:
334+
as_of_archive$DT <- copy(as_of_archive$DT)
335+
}
336+
dt_key = data.table::key(as_of_archive$DT)
337+
as_of_df = as_of_archive$DT
338+
data.table::setDF(as_of_df)
417339

418-
# Convert each subgroup chunk to an archive before running the calculation.
419-
group_modify_fn = function(.data_group, .group_key,
420-
f, ...,
421-
ref_time_value,
422-
new_col) {
423-
# .data_group is coming from as_of_df as a tibble, but we
424-
# want to feed `comp_one_grp` an `epi_archive` backed by a
425-
# DT; convert and wrap:
426-
data.table::setattr(.data_group, "sorted", dt_key)
427-
data.table::setDT(.data_group, key=dt_key)
428-
.data_group_archive = as_of_archive$clone()
429-
.data_group_archive$DT = .data_group
430-
comp_one_grp(.data_group_archive, .group_key, f = f, quo = quo,
431-
ref_time_value = ref_time_value,
432-
new_col = new_col
433-
)
434-
}
340+
# Convert each subgroup chunk to an archive before running the calculation.
341+
group_modify_fn = function(.data_group, .group_key,
342+
f, ...,
343+
ref_time_value,
344+
new_col) {
345+
# .data_group is coming from as_of_df as a tibble, but we
346+
# want to feed `comp_one_grp` an `epi_archive` backed by a
347+
# DT; convert and wrap:
348+
data.table::setattr(.data_group, "sorted", dt_key)
349+
data.table::setDT(.data_group, key=dt_key)
350+
.data_group_archive = as_of_archive$clone()
351+
.data_group_archive$DT = .data_group
352+
comp_one_grp(.data_group_archive, .group_key, f = f, ...,
353+
ref_time_value = ref_time_value,
354+
new_col = new_col
355+
)
435356
}
357+
}
436358

437-
return(
438-
dplyr::group_by(as_of_df, dplyr::across(tidyselect::all_of(private$vars)),
439-
.drop=private$drop) %>%
440-
dplyr::group_modify(group_modify_fn,
441-
f = f, quo = quo,
442-
ref_time_value = ref_time_value,
443-
comp_effective_key_vars = comp_effective_key_vars,
444-
new_col = new_col,
445-
.keep = TRUE)
446-
)
447-
})
448-
}
359+
return(
360+
dplyr::group_by(as_of_df, dplyr::across(tidyselect::all_of(private$vars)),
361+
.drop=private$drop) %>%
362+
dplyr::group_modify(group_modify_fn,
363+
f = f, ...,
364+
ref_time_value = ref_time_value,
365+
new_col = new_col,
366+
.keep = TRUE)
367+
)
368+
})
449369

450370
# Unchop/unnest if we need to
451371
if (!as_list_col) {

R/slide.R

Lines changed: 24 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@
123123
#'
124124
#' @importFrom lubridate days weeks
125125
#' @importFrom dplyr bind_rows group_vars filter select
126-
#' @importFrom rlang .data .env !! enquo enquos sym env
126+
#' @importFrom rlang .data .env !! enquo enquos sym env missing_arg
127127
#' @export
128128
#' @examples
129129
#' # slide a 7-day trailing average formula on cases
@@ -167,11 +167,6 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
167167

168168
# Check we have an `epi_df` object
169169
if (!inherits(x, "epi_df")) Abort("`x` must be of class `epi_df`.")
170-
171-
# Check that `f` takes enough args
172-
if (!missing(f) && is.function(f)) {
173-
assert_sufficient_f_args(f, ...)
174-
}
175170

176171
if (missing(ref_time_values)) {
177172
ref_time_values = unique(x$time_value)
@@ -356,28 +351,8 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
356351
return(mutate(.data_group, !!new_col := slide_values))
357352
}
358353

359-
# If f is not missing, then just go ahead, slide by group
360-
if (!missing(f)) {
361-
if (rlang::is_formula(f)) f = as_slide_computation(f)
362-
f_rtv_wrapper = function(x, g, ...) {
363-
ref_time_value = min(x$time_value) + before
364-
x <- x[x$.real,]
365-
x$.real <- NULL
366-
f(x, g, ref_time_value, ...)
367-
}
368-
x = x %>%
369-
group_modify(slide_one_grp,
370-
f = f_rtv_wrapper, ...,
371-
starts = starts,
372-
stops = stops,
373-
time_values = ref_time_values,
374-
all_rows = all_rows,
375-
new_col = new_col,
376-
.keep = FALSE)
377-
}
378-
379-
# Else interpret ... as an expression for tidy evaluation
380-
else {
354+
# If `f` is missing, interpret ... as an expression for tidy evaluation
355+
if (missing(f)) {
381356
quos = enquos(...)
382357
if (length(quos) == 0) {
383358
Abort("If `f` is missing then a computation must be specified via `...`.")
@@ -386,31 +361,29 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
386361
Abort("If `f` is missing then only a single computation can be specified via `...`.")
387362
}
388363

389-
quo = quos[[1]]
390-
f = function(.x, .group_key, quo, ...) {
391-
.ref_time_value = min(.x$time_value) + before
392-
.x <- .x[.x$.real,]
393-
.x$.real <- NULL
394-
data_mask = rlang::as_data_mask(.x)
395-
# We'll also install `.x` directly, not as an `rlang_data_pronoun`, so
396-
# that we can, e.g., use more dplyr and epiprocess operations.
397-
data_mask$.x = .x
398-
data_mask$.group_key = .group_key
399-
data_mask$.ref_time_value = .ref_time_value
400-
rlang::eval_tidy(quo, data_mask)
401-
}
364+
f = quos[[1]]
402365
new_col = sym(names(rlang::quos_auto_name(quos)))
403-
404-
x = x %>%
405-
group_modify(slide_one_grp,
406-
f = f, quo = quo,
407-
starts = starts,
408-
stops = stops,
409-
time_values = ref_time_values,
410-
all_rows = all_rows,
411-
new_col = new_col,
412-
.keep = FALSE)
366+
... = missing_arg() # magic value that passes zero args as dots in calls below
367+
}
368+
369+
f = as_slide_computation(f, ...)
370+
# Create a wrapper that calculates and passes `.ref_time_value` to the
371+
# computation.
372+
f_wrapper = function(.x, .group_key, ...) {
373+
.ref_time_value = min(.x$time_value) + before
374+
.x <- .x[.x$.real,]
375+
.x$.real <- NULL
376+
f(.x, .group_key, .ref_time_value, ...)
413377
}
378+
x = x %>%
379+
group_modify(slide_one_grp,
380+
f = f_wrapper, ...,
381+
starts = starts,
382+
stops = stops,
383+
time_values = ref_time_values,
384+
all_rows = all_rows,
385+
new_col = new_col,
386+
.keep = FALSE)
414387

415388
# Unnest if we need to, and return
416389
if (!as_list_col) {

0 commit comments

Comments
 (0)