Skip to content

Commit

Permalink
function to get prediction columns (#1224)
Browse files Browse the repository at this point in the history
* function to get prediction columns

* forgotten pkgdown entry

* also, bump version number

* fix for workflows

* Apply suggestions from code review

Co-authored-by: Emil Hvitfeldt <[email protected]>

---------

Co-authored-by: Emil Hvitfeldt <[email protected]>
  • Loading branch information
topepo and EmilHvitfeldt authored Dec 8, 2024
1 parent a212f78 commit 27df158
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 1 deletion.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.2.1.9003
Version: 1.2.1.9004
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ export(.dat)
export(.extract_surv_status)
export(.extract_surv_time)
export(.facts)
export(.get_prediction_column_names)
export(.lvls)
export(.model_param_name_key)
export(.obs)
Expand Down
72 changes: 72 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,75 @@ is_cran_check <- function() {
}
# nocov end

# ------------------------------------------------------------------------------

#' Obtain names of prediction columns for a fitted model or workflow
#'
#' [.get_prediction_column_names()] returns a list that has the names of the
#' columns for the primary prediction types for a model.
#' @param x A fitted parsnip model (class `"model_fit"`) or a fitted workflow.
#' @param syms Should the column names be converted to symbols? Defaults to `FALSE`.
#' @return A list with elements `"estimate"` and `"probabilities"`.
#' @examplesIf !parsnip:::is_cran_check()
#' library(dplyr)
#' library(modeldata)
#' data("two_class_dat")
#'
#' levels(two_class_dat$Class)
#' lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat)
#'
#' .get_prediction_column_names(lr_fit)
#' .get_prediction_column_names(lr_fit, syms = TRUE)
#' @export
.get_prediction_column_names <- function(x, syms = FALSE) {
if (!inherits(x, c("model_fit", "workflow"))) {
cli::cli_abort("{.arg x} should be an object with class {.cls model_fit} or
{.cls workflow}, not {.obj_type_friendly {x}}.")
}

if (inherits(x, "workflow")) {
x <- x %>% extract_fit_parsnip(x)
}
model_spec <- extract_spec_parsnip(x)
model_engine <- model_spec$engine
model_mode <- model_spec$mode
model_type <- class(model_spec)[1]

# appropriate populate the model db
inst_res <- purrr::map(required_pkgs(x), rlang::check_installed)
predict_types <-
get_from_env(paste0(model_type, "_predict")) %>%
dplyr::filter(engine == model_engine & mode == model_mode) %>%
purrr::pluck("type")

if (length(predict_types) == 0) {
cli::cli_abort("Prediction information could not be found for this
{.fn {model_type}} with engine {.val {model_engine}} and mode
{.val {model_mode}}. Does a parsnip extension package need to
be loaded?")
}

res <- list(estimate = character(0), probabilities = character(0))

if (model_mode == "regression") {
res$estimate <- ".pred"
} else if (model_mode == "classification") {
res$estimate <- ".pred_class"
if (any(predict_types == "prob")) {
res$probabilities <- paste0(".pred_", x$lvl)
}
} else if (model_mode == "censored regression") {
res$estimate <- ".pred_time"
if (any(predict_types %in% c("survival"))) {
res$probabilities <- ".pred"
}
} else {
# Should be unreachable
cli::cli_abort("Unsupported model mode {model_mode}.")
}

if (syms) {
res <- purrr::map(res, rlang::syms)
}
res
}
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,4 @@ reference:
- .extract_surv_status
- .extract_surv_time
- .model_param_name_key
- .get_prediction_column_names
33 changes: 33 additions & 0 deletions man/dot-get_prediction_column_names.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions tests/testthat/_snaps/misc.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,19 @@
Error in `check_outcome()`:
! For a censored regression model, the outcome should be a <Surv> object, not an integer vector.

# obtaining prediction columns

Code
.get_prediction_column_names(1)
Condition
Error in `.get_prediction_column_names()`:
! `x` should be an object with class <model_fit> or <workflow>, not a number.

---

Code
.get_prediction_column_names(unk_fit)
Condition
Error in `.get_prediction_column_names()`:
! Prediction information could not be found for this `linear_reg()` with engine "lm" and mode "Depeche". Does a parsnip extension package need to be loaded?

50 changes: 50 additions & 0 deletions tests/testthat/test-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,53 @@ test_that('check_outcome works as expected', {
check_outcome(1:2, cens_spec)
)
})

# ------------------------------------------------------------------------------

test_that('obtaining prediction columns', {
skip_if_not_installed("modeldata")
data(two_class_dat, package = "modeldata")

### classification
lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat)
expect_equal(
.get_prediction_column_names(lr_fit),
list(estimate = ".pred_class",
probabilities = c(".pred_Class1", ".pred_Class2"))
)
expect_equal(
.get_prediction_column_names(lr_fit, syms = TRUE),
list(estimate = list(quote(.pred_class)),
probabilities = list(quote(.pred_Class1), quote(.pred_Class2)))
)

### regression
ols_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars)
expect_equal(
.get_prediction_column_names(ols_fit),
list(estimate = ".pred",
probabilities = character(0))
)
expect_equal(
.get_prediction_column_names(ols_fit, syms = TRUE),
list(estimate = list(quote(.pred)),
probabilities = list())
)

### censored regression
# in extratests

### bad input
expect_snapshot(
.get_prediction_column_names(1),
error = TRUE
)

unk_fit <- ols_fit
unk_fit$spec$mode <- "Depeche"
expect_snapshot(
.get_prediction_column_names(unk_fit),
error = TRUE
)

})

0 comments on commit 27df158

Please sign in to comment.