Skip to content

Commit 3b2572c

Browse files
topepohfrick
andauthored
bug fix for non-package models (#1230)
* push back a month * redoc * changes for #1229 * small cleanup * update news * snapshot * Apply suggestions from code review Co-authored-by: Hannah Frick <[email protected]> --------- Co-authored-by: Hannah Frick <[email protected]>
1 parent 8adb5cc commit 3b2572c

File tree

5 files changed

+67
-9
lines changed

5 files changed

+67
-9
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333

3434
* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).
3535

36+
* Fixed bug related to using local (non-package) models (#1229)
37+
3638
* `tunable()` now references a dials object for the `mixture` parameter (#1236)
3739

3840
## Breaking Change

R/misc.R

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -241,13 +241,15 @@ prompt_missing_implementation <- function(spec,
241241
#' @keywords internal
242242
#' @export
243243
show_call <- function(object) {
244-
object$method$fit$args <-
245-
map(object$method$fit$args, convert_arg)
244+
object$method$fit$args <- map(object$method$fit$args, convert_arg)
246245

247-
call2(object$method$fit$func["fun"],
248-
!!!object$method$fit$args,
249-
.ns = object$method$fit$func["pkg"]
250-
)
246+
fn_info <- as.list(object$method$fit$func)
247+
if (!any(names(fn_info) == "pkg")) {
248+
res <- call2(fn_info$fun, !!!object$method$fit$args)
249+
} else {
250+
res <- call2(fn_info$fun, !!!object$method$fit$args, .ns = fn_info$pkg)
251+
}
252+
res
251253
}
252254

253255
convert_arg <- function(x) {
@@ -301,8 +303,8 @@ check_args.default <- function(object, call = rlang::caller_env()) {
301303

302304
# ------------------------------------------------------------------------------
303305

304-
# copied form recipes
305-
306+
# copied from recipes
307+
# nocov start
306308
names0 <- function(num, prefix = "x", call = rlang::caller_env()) {
307309
if (num < 1) {
308310
cli::cli_abort("{.arg num} should be > 0.", call = call)
@@ -311,7 +313,7 @@ names0 <- function(num, prefix = "x", call = rlang::caller_env()) {
311313
ind <- gsub(" ", "0", ind)
312314
paste0(prefix, ind)
313315
}
314-
316+
# nocov end
315317

316318
# ------------------------------------------------------------------------------
317319

parsnip.Rproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
Version: 1.0
22
ProjectId: 7f6c9ff5-6b9a-4235-8666-12db5ef65d49
33

4+
45
RestoreWorkspace: No
56
SaveWorkspace: No
67
AlwaysSaveHistory: Default

tests/testthat/_snaps/misc.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,15 @@
243243
Error in `.get_prediction_column_names()`:
244244
! Prediction information could not be found for this `linear_reg()` with engine "lm" and mode "Depeche". Does a parsnip extension package need to be loaded?
245245

246+
# register local models
247+
248+
Code
249+
my_model() %>% translate("my_engine")
250+
Output
251+
my model Model Specification (regression)
252+
253+
Computational engine: my_engine
254+
255+
Model fit template:
256+
my_model_fun(formula = missing_arg(), data = missing_arg())
257+

tests/testthat/test-misc.R

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,44 @@ test_that('obtaining prediction columns', {
299299
)
300300

301301
})
302+
303+
304+
# ------------------------------------------------------------------------------
305+
306+
# https://github.com/tidymodels/parsnip/issues/1229
307+
test_that('register local models', {
308+
set_new_model("my_model")
309+
set_model_mode(model = "my_model", mode = "regression")
310+
set_model_engine(
311+
"my_model",
312+
mode = "regression",
313+
eng = "my_engine"
314+
)
315+
316+
my_model <-
317+
function(mode = "regression") {
318+
new_model_spec(
319+
"my_model",
320+
args = list(),
321+
eng_args = NULL,
322+
mode = mode,
323+
method = NULL,
324+
engine = NULL
325+
)
326+
}
327+
328+
set_fit(
329+
model = "my_model",
330+
eng = "my_engine",
331+
mode = "regression",
332+
value = list(
333+
interface = "matrix",
334+
protect = c("formula", "data"),
335+
func = c(fun = "my_model_fun"),
336+
defaults = list()
337+
)
338+
)
339+
340+
expect_snapshot(my_model() %>% translate("my_engine"))
341+
})
342+

0 commit comments

Comments
 (0)