Skip to content

Commit f76961c

Browse files
authored
Merge pull request #374 from cmu-delphi/ds/quantreg-method
feat: expose "method" arg of quantile_reg
2 parents b4d4071 + 1040633 commit f76961c

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

R/make_quantile_reg.R

+12-8
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#' "rq" and "grf" are supported.
1414
#' @param quantile_levels A scalar or vector of values in (0, 1) to determine which
1515
#' quantiles to estimate (default is 0.5).
16+
#' @param method A fitting method used by [quantreg::rq()]. See the
17+
#' documentation for a list of options.
1618
#'
1719
#' @export
1820
#'
@@ -25,7 +27,7 @@
2527
#' rq_spec <- quantile_reg(quantile_levels = c(.2, .8)) %>% set_engine("rq")
2628
#' ff <- rq_spec %>% fit(y ~ ., data = tib)
2729
#' predict(ff, new_data = tib)
28-
quantile_reg <- function(mode = "regression", engine = "rq", quantile_levels = 0.5) {
30+
quantile_reg <- function(mode = "regression", engine = "rq", quantile_levels = 0.5, method = "br") {
2931
# Check for correct mode
3032
if (mode != "regression") {
3133
cli_abort("`mode` must be 'regression'")
@@ -38,7 +40,7 @@ quantile_reg <- function(mode = "regression", engine = "rq", quantile_levels = 0
3840
cli::cli_warn("Sorting `quantile_levels` to increasing order.")
3941
quantile_levels <- sort(quantile_levels)
4042
}
41-
args <- list(quantile_levels = rlang::enquo(quantile_levels))
43+
args <- list(quantile_levels = rlang::enquo(quantile_levels), method = rlang::enquo(method))
4244

4345
# Save some empty slots for future parts of the specification
4446
parsnip::new_model_spec(
@@ -57,9 +59,6 @@ make_quantile_reg <- function() {
5759
parsnip::set_new_model("quantile_reg")
5860
}
5961
parsnip::set_model_mode("quantile_reg", "regression")
60-
61-
62-
6362
parsnip::set_model_engine("quantile_reg", "regression", eng = "rq")
6463
parsnip::set_dependency("quantile_reg", eng = "rq", pkg = "quantreg")
6564

@@ -71,6 +70,14 @@ make_quantile_reg <- function() {
7170
func = list(pkg = "quantreg", fun = "rq"),
7271
has_submodel = FALSE
7372
)
73+
parsnip::set_model_arg(
74+
model = "quantile_reg",
75+
eng = "rq",
76+
parsnip = "method",
77+
original = "method",
78+
func = list(pkg = "quantreg", fun = "rq"),
79+
has_submodel = FALSE
80+
)
7481

7582
parsnip::set_fit(
7683
model = "quantile_reg",
@@ -81,7 +88,6 @@ make_quantile_reg <- function() {
8188
protect = c("formula", "data", "weights"),
8289
func = c(pkg = "quantreg", fun = "rq"),
8390
defaults = list(
84-
method = "br",
8591
na.action = rlang::expr(stats::na.omit),
8692
model = FALSE
8793
)
@@ -104,7 +110,6 @@ make_quantile_reg <- function() {
104110
object <- parsnip::extract_fit_engine(object)
105111
type <- class(object)[1]
106112

107-
108113
# can't make a method because object is second
109114
out <- switch(type,
110115
rq = dist_quantiles(unname(as.list(x)), object$quantile_levels), # one quantile
@@ -120,7 +125,6 @@ make_quantile_reg <- function() {
120125
return(dplyr::tibble(.pred = out))
121126
}
122127

123-
124128
parsnip::set_pred(
125129
model = "quantile_reg",
126130
eng = "rq",

man/quantile_reg.Rd

+9-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)