Skip to content

Commit 6a24743

Browse files
dshemetovdsweber2
authored andcommitted
fix: review tweaks
* vectorize in lambda * inheritParams in docs * lambda -> yj_param in many places
1 parent 95c50b5 commit 6a24743

File tree

7 files changed

+191
-172
lines changed

7 files changed

+191
-172
lines changed

R/layer_yeo_johnson.R

Lines changed: 58 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,10 @@
22
#'
33
#' Will undo a step_epi_YeoJohnson transformation.
44
#'
5-
#' @param frosting a `frosting` postprocessor. The layer will be added to the
6-
#' sequence of operations for this frosting.
7-
#' @param lambdas Internal. A data frame of lambda values to be used for
5+
#' @inheritParams layer_population_scaling
6+
#' @param yj_params Internal. A data frame of parameters to be used for
87
#' inverting the transformation.
9-
#' @param ... One or more selector functions to scale variables
10-
#' for this step. See [recipes::selections()] for more details.
118
#' @param by A (possibly named) character vector of variables to join by.
12-
#' @param id a random id string
139
#'
1410
#' @return an updated `frosting` postprocessor
1511
#' @export
@@ -41,39 +37,41 @@
4137
#' # Compare to the original data.
4238
#' jhu %>% filter(time_value == "2021-12-31")
4339
#' forecast(wf)
44-
layer_epi_YeoJohnson <- function(frosting, ..., lambdas = NULL, by = NULL, id = rand_id("epi_YeoJohnson")) {
45-
checkmate::assert_tibble(lambdas, min.rows = 1, null.ok = TRUE)
40+
layer_epi_YeoJohnson <- function(frosting, ..., yj_params = NULL, by = NULL, id = rand_id("epi_YeoJohnson")) {
41+
checkmate::assert_tibble(yj_params, min.rows = 1, null.ok = TRUE)
4642

4743
add_layer(
4844
frosting,
4945
layer_epi_YeoJohnson_new(
50-
lambdas = lambdas,
46+
yj_params = yj_params,
5147
by = by,
5248
terms = dplyr::enquos(...),
5349
id = id
5450
)
5551
)
5652
}
5753

58-
layer_epi_YeoJohnson_new <- function(lambdas, by, terms, id) {
59-
layer("epi_YeoJohnson", lambdas = lambdas, by = by, terms = terms, id = id)
54+
layer_epi_YeoJohnson_new <- function(yj_params, by, terms, id) {
55+
layer("epi_YeoJohnson", yj_params = yj_params, by = by, terms = terms, id = id)
6056
}
6157

6258
#' @export
6359
#' @importFrom workflows extract_preprocessor
6460
slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data, ...) {
6561
rlang::check_dots_empty()
6662

67-
# Get the lambdas from the layer or from the workflow.
68-
lambdas <- object$lambdas %||% get_lambdas_in_layer(workflow)
63+
# TODO: We will error if we don't have a workflow. Write a check later.
6964

70-
# If the by is not specified, try to infer it from the lambdas.
65+
# Get the yj_params from the layer or from the workflow.
66+
yj_params <- object$yj_params %||% get_yj_params_in_layer(workflow)
67+
68+
# If the by is not specified, try to infer it from the yj_params.
7169
if (is.null(object$by)) {
7270
# Assume `layer_predict` has calculated the prediction keys and other
7371
# layers don't change the prediction key colnames:
7472
prediction_key_colnames <- names(components$keys)
7573
lhs_potential_keys <- prediction_key_colnames
76-
rhs_potential_keys <- colnames(select(lambdas, -starts_with("lambda_")))
74+
rhs_potential_keys <- colnames(select(yj_params, -starts_with(".yj_param_")))
7775
object$by <- intersect(lhs_potential_keys, rhs_potential_keys)
7876
suggested_min_keys <- setdiff(lhs_potential_keys, "time_value")
7977
if (!all(suggested_min_keys %in% object$by)) {
@@ -95,16 +93,16 @@ slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data,
9593
object$by <- object$by %||%
9694
intersect(
9795
epi_keys_only(components$predictions),
98-
colnames(select(lambdas, -starts_with(".lambda_")))
96+
colnames(select(yj_params, -starts_with(".yj_param_")))
9997
)
10098
joinby <- list(x = names(object$by) %||% object$by, y = object$by)
10199
hardhat::validate_column_names(components$predictions, joinby$x)
102-
hardhat::validate_column_names(lambdas, joinby$y)
100+
hardhat::validate_column_names(yj_params, joinby$y)
103101

104-
# Join the lambdas.
102+
# Join the yj_params.
105103
components$predictions <- inner_join(
106104
components$predictions,
107-
lambdas,
105+
yj_params,
108106
by = object$by,
109107
relationship = "many-to-one",
110108
unmatched = c("error", "drop")
@@ -115,7 +113,7 @@ slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data,
115113
col_names <- names(pos)
116114

117115
# The `object$terms` is where the user specifies the columns they want to
118-
# untransform. We need to match the outcomes with their lambda columns in our
116+
# untransform. We need to match the outcomes with their yj_param columns in our
119117
# parameter table and then apply the inverse transformation.
120118
if (identical(col_names, ".pred")) {
121119
# In this case, we don't get a hint for the outcome column name, so we need
@@ -130,8 +128,7 @@ slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data,
130128
magrittr::extract(, 2)
131129

132130
components$predictions <- components$predictions %>%
133-
rowwise() %>%
134-
mutate(.pred := yj_inverse(.pred, !!sym(paste0(".lambda_", outcome_cols))))
131+
mutate(.pred := yj_inverse(.pred, !!sym(paste0(".yj_param_", outcome_cols))))
135132
} else if (identical(col_names, character(0))) {
136133
# Wish I could suggest `all_outcomes()` here, but currently it's the same as
137134
# not specifying any terms. I don't want to spend time with dealing with
@@ -146,10 +143,10 @@ slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data,
146143
)
147144
} else {
148145
# In this case, we assume that the user has specified the columns they want
149-
# transformed here. We then need to determine the lambda columns for each of
146+
# transformed here. We then need to determine the yj_param columns for each of
150147
# these columns. That is, we need to convert a vector of column names like
151148
# c(".pred_ahead_1_case_rate", ".pred_ahead_7_case_rate") to
152-
# c("lambda_ahead_1_case_rate", "lambda_ahead_7_case_rate").
149+
# c(".yj_param_ahead_1_case_rate", ".yj_param_ahead_7_case_rate").
153150
original_outcome_cols <- stringr::str_match(col_names, ".pred_ahead_\\d+_(.*)")[, 2]
154151
outcomes_wout_ahead <- stringr::str_match(names(components$mold$outcomes), "ahead_\\d+_(.*)")[, 2]
155152
if (any(original_outcome_cols %nin% outcomes_wout_ahead)) {
@@ -163,34 +160,37 @@ slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data,
163160

164161
for (i in seq_along(col_names)) {
165162
col <- col_names[i]
166-
lambda_col <- paste0(".lambda_", original_outcome_cols[i])
163+
yj_param_col <- paste0(".yj_param_", original_outcome_cols[i])
167164
components$predictions <- components$predictions %>%
168-
rowwise() %>%
169-
mutate(!!sym(col) := yj_inverse(!!sym(col), !!sym(lambda_col)))
165+
mutate(!!sym(col) := yj_inverse(!!sym(col), !!sym(yj_param_col)))
170166
}
171167
}
172168

173-
# Remove the lambda columns.
169+
# Remove the yj_param columns.
174170
components$predictions <- components$predictions %>%
175-
select(-any_of(starts_with(".lambda_"))) %>%
171+
select(-any_of(starts_with(".yj_param_"))) %>%
176172
ungroup()
177173
components
178174
}
179175

180176
#' @export
181177
print.layer_epi_YeoJohnson <- function(x, width = max(20, options()$width - 30), ...) {
182-
title <- "Yeo-Johnson transformation (see `lambdas` object for values) on "
178+
title <- "Yeo-Johnson transformation (see `yj_params` object for values) on "
183179
print_layer(x$terms, title = title, width = width)
184180
}
185181

186182
# Inverse Yeo-Johnson transformation
187183
#
188-
# Inverse of `yj_transform` in step_yeo_johnson.R. Note that this function is
189-
# vectorized in x, but not in lambda.
184+
# Inverse of `yj_transform` in step_yeo_johnson.R.
190185
yj_inverse <- function(x, lambda, eps = 0.001) {
191-
if (is.na(lambda)) {
186+
if (any(is.na(lambda))) {
192187
return(x)
193188
}
189+
if (length(x) > 1 && length(lambda) == 1) {
190+
lambda <- rep(lambda, length(x))
191+
} else if (length(x) != length(lambda)) {
192+
cli::cli_abort("Length of `x` must be equal to length of `lambda`.", call = rlang::caller_fn())
193+
}
194194
if (!inherits(x, "tbl_df") || is.data.frame(x)) {
195195
x <- unlist(x, use.names = FALSE)
196196
} else {
@@ -199,52 +199,58 @@ yj_inverse <- function(x, lambda, eps = 0.001) {
199199
}
200200
}
201201

202-
dat_neg <- x < 0
203-
ind_neg <- list(is = which(dat_neg), not = which(!dat_neg))
204-
not_neg <- ind_neg[["not"]]
205-
is_neg <- ind_neg[["is"]]
206-
207202
nn_inv_trans <- function(x, lambda) {
208203
out <- double(length(x))
209204
sm_lambdas <- abs(lambda) < eps
210-
out[sm_lambdas] <- exp(x[sm_lambdas]) - 1
205+
if (length(sm_lambdas) > 0) {
206+
out[sm_lambdas] <- exp(x[sm_lambdas]) - 1
207+
}
211208
x <- x[!sm_lambdas]
212209
lambda <- lambda[!sm_lambdas]
213-
out[!sm_lambdas] <- (lambda * x + 1)^(1 / lambda) - 1
210+
if (length(x) > 0) {
211+
out[!sm_lambdas] <- (lambda * x + 1)^(1 / lambda) - 1
212+
}
214213
out
215214
}
216-
}
217215

218216
ng_inv_trans <- function(x, lambda) {
219-
if (abs(lambda - 2) < eps) {
220-
# -log(-x + 1)
221-
-(exp(-x) - 1)
222-
} else {
223-
# -((-x + 1)^(2 - lambda) - 1) / (2 - lambda)
224-
-(((lambda - 2) * x + 1)^(1 / (2 - lambda)) - 1)
217+
out <- double(length(x))
218+
near2_lambdas <- abs(lambda - 2) < eps
219+
if (length(near2_lambdas) > 0) {
220+
out[near2_lambdas] <- -(exp(-x[near2_lambdas]) - 1)
221+
}
222+
x <- x[!near2_lambdas]
223+
lambda <- lambda[!near2_lambdas]
224+
if (length(x) > 0) {
225+
out[!near2_lambdas] <- -(((lambda - 2) * x + 1)^(1 / (2 - lambda)) - 1)
225226
}
227+
out
226228
}
227229

230+
dat_neg <- x < 0
231+
not_neg <- which(!dat_neg)
232+
is_neg <- which(dat_neg)
233+
228234
if (length(not_neg) > 0) {
229-
x[not_neg] <- nn_inv_trans(x[not_neg], lambda)
235+
x[not_neg] <- nn_inv_trans(x[not_neg], lambda[not_neg])
230236
}
231237

232238
if (length(is_neg) > 0) {
233-
x[is_neg] <- ng_inv_trans(x[is_neg], lambda)
239+
x[is_neg] <- ng_inv_trans(x[is_neg], lambda[is_neg])
234240
}
235241
x
236242
}
237243

238-
get_lambdas_in_layer <- function(workflow) {
244+
get_yj_params_in_layer <- function(workflow) {
239245
this_recipe <- hardhat::extract_recipe(workflow)
240246
if (!(this_recipe %>% recipes::detect_step("epi_YeoJohnson"))) {
241247
cli_abort("`layer_epi_YeoJohnson` requires `step_epi_YeoJohnson` in the recipe.", call = rlang::caller_env())
242248
}
243249
for (step in this_recipe$steps) {
244250
if (inherits(step, "step_epi_YeoJohnson")) {
245-
lambdas <- step$lambdas
251+
yj_params <- step$yj_params
246252
break
247253
}
248254
}
249-
lambdas
255+
yj_params
250256
}

R/step_adjust_latency.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,6 @@ step_adjust_latency_new <-
272272
# lags introduces max(lags) NA's after the max_time_value.
273273
#' @export
274274
#' @importFrom glue glue
275-
#' @importFrom dplyr rowwise
276275
prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
277276
latency <- x$latency
278277
col_names <- recipes::recipes_eval_select(x$terms, training, info)

0 commit comments

Comments
 (0)