Skip to content

Commit

Permalink
bug fix for non-package models (#1230)
Browse files Browse the repository at this point in the history
* push back a month

* redoc

* changes for #1229

* small cleanup

* update news

* snapshot

* Apply suggestions from code review

Co-authored-by: Hannah Frick <[email protected]>

---------

Co-authored-by: Hannah Frick <[email protected]>
  • Loading branch information
topepo and hfrick authored Jan 29, 2025
1 parent 8adb5cc commit 3b2572c
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 9 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).

* Fixed bug related to using local (non-package) models (#1229)

* `tunable()` now references a dials object for the `mixture` parameter (#1236)

## Breaking Change
Expand Down
20 changes: 11 additions & 9 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,15 @@ prompt_missing_implementation <- function(spec,
#' @keywords internal
#' @export
show_call <- function(object) {
object$method$fit$args <-
map(object$method$fit$args, convert_arg)
object$method$fit$args <- map(object$method$fit$args, convert_arg)

call2(object$method$fit$func["fun"],
!!!object$method$fit$args,
.ns = object$method$fit$func["pkg"]
)
fn_info <- as.list(object$method$fit$func)
if (!any(names(fn_info) == "pkg")) {
res <- call2(fn_info$fun, !!!object$method$fit$args)
} else {
res <- call2(fn_info$fun, !!!object$method$fit$args, .ns = fn_info$pkg)
}
res
}

convert_arg <- function(x) {
Expand Down Expand Up @@ -301,8 +303,8 @@ check_args.default <- function(object, call = rlang::caller_env()) {

# ------------------------------------------------------------------------------

# copied form recipes

# copied from recipes
# nocov start
names0 <- function(num, prefix = "x", call = rlang::caller_env()) {
if (num < 1) {
cli::cli_abort("{.arg num} should be > 0.", call = call)
Expand All @@ -311,7 +313,7 @@ names0 <- function(num, prefix = "x", call = rlang::caller_env()) {
ind <- gsub(" ", "0", ind)
paste0(prefix, ind)
}

# nocov end

# ------------------------------------------------------------------------------

Expand Down
1 change: 1 addition & 0 deletions parsnip.Rproj
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
Version: 1.0
ProjectId: 7f6c9ff5-6b9a-4235-8666-12db5ef65d49


RestoreWorkspace: No
SaveWorkspace: No
AlwaysSaveHistory: Default
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/_snaps/misc.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,15 @@
Error in `.get_prediction_column_names()`:
! Prediction information could not be found for this `linear_reg()` with engine "lm" and mode "Depeche". Does a parsnip extension package need to be loaded?

# register local models

Code
my_model() %>% translate("my_engine")
Output
my model Model Specification (regression)
Computational engine: my_engine
Model fit template:
my_model_fun(formula = missing_arg(), data = missing_arg())

41 changes: 41 additions & 0 deletions tests/testthat/test-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,44 @@ test_that('obtaining prediction columns', {
)

})


# ------------------------------------------------------------------------------

# https://github.com/tidymodels/parsnip/issues/1229
test_that('register local models', {
set_new_model("my_model")
set_model_mode(model = "my_model", mode = "regression")
set_model_engine(
"my_model",
mode = "regression",
eng = "my_engine"
)

my_model <-
function(mode = "regression") {
new_model_spec(
"my_model",
args = list(),
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
)
}

set_fit(
model = "my_model",
eng = "my_engine",
mode = "regression",
value = list(
interface = "matrix",
protect = c("formula", "data"),
func = c(fun = "my_model_fun"),
defaults = list()
)
)

expect_snapshot(my_model() %>% translate("my_engine"))
})

0 comments on commit 3b2572c

Please sign in to comment.