Skip to content

Commit 01fb08c

Browse files
authored
Merge branch 'main' into gam-775
2 parents 2daaa91 + 8880aff commit 01fb08c

File tree

128 files changed

+1012
-506
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

128 files changed

+1012
-506
lines changed

.github/workflows/R-CMD-check.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ jobs:
6262
- name: Install Miniconda
6363
# conda can fail at downgrading python, so we specify python version in advance
6464
env:
65-
RETICULATE_MINICONDA_PYTHON_VERSION: "3.7"
65+
RETICULATE_MINICONDA_PYTHON_VERSION: "3.8"
6666
run: reticulate::install_miniconda() # creates r-reticulate conda env by default
6767
shell: Rscript {0}
6868

6969
- name: Install TensorFlow
7070
run: |
71-
tensorflow::install_tensorflow(version='2.7', conda_python_version = NULL)
71+
tensorflow::install_tensorflow(version='2.13', conda_python_version = NULL)
7272
shell: Rscript {0}
7373

7474
- uses: r-lib/actions/check-r-package@v2

.github/workflows/pkgdown.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ jobs:
3636
- name: Install Miniconda
3737
# conda can fail at downgrading python, so we specify python version in advance
3838
env:
39-
RETICULATE_MINICONDA_PYTHON_VERSION: "3.7"
39+
RETICULATE_MINICONDA_PYTHON_VERSION: "3.8"
4040
run: reticulate::install_miniconda() # creates r-reticulate conda env by default
4141
shell: Rscript {0}
4242

4343
- name: Install TensorFlow
4444
run: |
45-
tensorflow::install_tensorflow(version='2.7', conda_python_version = NULL)
45+
tensorflow::install_tensorflow(version='2.13', conda_python_version = NULL)
4646
shell: Rscript {0}
4747

4848
- name: Build site

.github/workflows/test-coverage.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ jobs:
3333
- name: Install Miniconda
3434
# conda can fail at downgrading python, so we specify python version in advance
3535
env:
36-
RETICULATE_MINICONDA_PYTHON_VERSION: "3.7"
36+
RETICULATE_MINICONDA_PYTHON_VERSION: "3.8"
3737
run: reticulate::install_miniconda() # creates r-reticulate conda env by default
3838
shell: Rscript {0}
3939

4040
- name: Install TensorFlow
4141
run: |
42-
tensorflow::install_tensorflow(version='2.7', conda_python_version = NULL)
42+
tensorflow::install_tensorflow(version='2.13', conda_python_version = NULL)
4343
shell: Rscript {0}
4444

4545
- name: Test coverage

DESCRIPTION

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.2.1.9000
3+
Version: 1.2.1.9001
44
Authors@R: c(
55
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
@@ -25,7 +25,7 @@ Imports:
2525
ggplot2,
2626
globals,
2727
glue,
28-
hardhat (>= 1.1.0),
28+
hardhat (>= 1.3.1.9000),
2929
lifecycle,
3030
magrittr,
3131
pillar,
@@ -77,4 +77,6 @@ Config/testthat/edition: 3
7777
Encoding: UTF-8
7878
LazyData: true
7979
Roxygen: list(markdown = TRUE)
80+
Remotes:
81+
tidymodels/hardhat
8082
RoxygenNote: 7.3.1

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ S3method(check_args,svm_linear)
2929
S3method(check_args,svm_poly)
3030
S3method(check_args,svm_rbf)
3131
S3method(extract_fit_engine,model_fit)
32+
S3method(extract_fit_time,model_fit)
3233
S3method(extract_parameter_dials,model_spec)
3334
S3method(extract_parameter_set_dials,model_spec)
3435
S3method(extract_spec_parsnip,model_fit)
@@ -222,6 +223,7 @@ export(discrim_quad)
222223
export(discrim_regularized)
223224
export(eval_args)
224225
export(extract_fit_engine)
226+
export(extract_fit_time)
225227
export(extract_parameter_dials)
226228
export(extract_parameter_set_dials)
227229
export(extract_spec_parsnip)
@@ -376,6 +378,7 @@ importFrom(generics,varying_args)
376378
importFrom(ggplot2,autoplot)
377379
importFrom(glue,glue_collapse)
378380
importFrom(hardhat,extract_fit_engine)
381+
importFrom(hardhat,extract_fit_time)
379382
importFrom(hardhat,extract_parameter_dials)
380383
importFrom(hardhat,extract_parameter_set_dials)
381384
importFrom(hardhat,extract_spec_parsnip)

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
# parsnip (development version)
22

3+
34
* `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775).
45

6+
* Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083).
7+
8+
* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).
9+
10+
511
# parsnip 1.2.1
612

713
* Added a missing `tidy()` method for survival analysis glmnet models (#1086).

R/bag_tree.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ update.bag_tree <-
8585
# ------------------------------------------------------------------------------
8686

8787
#' @export
88-
check_args.bag_tree <- function(object) {
89-
if (object$engine == "C5.0" && object$mode == "regression")
90-
stop("C5.0 is classification only.", call. = FALSE)
88+
check_args.bag_tree <- function(object, call = rlang::caller_env()) {
9189
invisible(object)
9290
}
9391

R/boost_tree.R

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -164,23 +164,15 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
164164
# ------------------------------------------------------------------------------
165165

166166
#' @export
167-
check_args.boost_tree <- function(object) {
167+
check_args.boost_tree <- function(object, call = rlang::caller_env()) {
168168

169169
args <- lapply(object$args, rlang::eval_tidy)
170170

171-
if (is.numeric(args$trees) && args$trees < 0) {
172-
rlang::abort("`trees` should be >= 1.")
173-
}
174-
if (is.numeric(args$sample_size) && (args$sample_size < 0 | args$sample_size > 1)) {
175-
rlang::abort("`sample_size` should be within [0,1].")
176-
}
177-
if (is.numeric(args$tree_depth) && args$tree_depth < 0) {
178-
rlang::abort("`tree_depth` should be >= 1.")
179-
}
180-
if (is.numeric(args$min_n) && args$min_n < 0) {
181-
rlang::abort("`min_n` should be >= 1.")
182-
}
183-
171+
check_number_whole(args$trees, min = 0, allow_null = TRUE, call = call, arg = "trees")
172+
check_number_decimal(args$sample_size, min = 0, max = 1, allow_null = TRUE, call = call, arg = "sample_size")
173+
check_number_whole(args$tree_depth, min = 0, allow_null = TRUE, call = call, arg = "tree_depth")
174+
check_number_whole(args$min_n, min = 0, allow_null = TRUE, call = call, arg = "min_n")
175+
184176
invisible(object)
185177
}
186178

R/c5_rules.R

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -111,32 +111,23 @@ update.C5_rules <-
111111
# make work in different places
112112

113113
#' @export
114-
check_args.C5_rules <- function(object) {
114+
check_args.C5_rules <- function(object, call = rlang::caller_env()) {
115115

116116
args <- lapply(object$args, rlang::eval_tidy)
117117

118-
if (is.numeric(args$trees)) {
119-
if (length(args$trees) > 1) {
120-
rlang::abort("Only a single value of `trees` is used.")
121-
}
122-
msg <- "The number of trees should be >= 1 and <= 100. Truncating the value."
123-
if (args$trees > 100) {
124-
object$args$trees <-
125-
rlang::new_quosure(100L, env = rlang::empty_env())
126-
rlang::warn(msg)
127-
}
128-
if (args$trees < 1) {
129-
object$args$trees <-
130-
rlang::new_quosure(1L, env = rlang::empty_env())
131-
rlang::warn(msg)
132-
}
118+
check_number_whole(args$min_n, allow_null = TRUE, call = call, arg = "min_n")
119+
check_number_whole(args$tree, allow_null = TRUE, call = call, arg = "tree")
133120

121+
msg <- "The number of trees should be {.code >= 1} and {.code <= 100}"
122+
if (!(is.null(args$trees)) && args$trees > 100) {
123+
object$args$trees <- rlang::new_quosure(100L, env = rlang::empty_env())
124+
cli::cli_warn(c(msg, "Truncating to 100."))
134125
}
135-
if (is.numeric(args$min_n)) {
136-
if (length(args$min_n) > 1) {
137-
rlang::abort("Only a single `min_n`` value is used.")
138-
}
126+
if (!(is.null(args$trees)) && args$trees < 1) {
127+
object$args$trees <- rlang::new_quosure(1L, env = rlang::empty_env())
128+
cli::cli_warn(c(msg, "Truncating to 1."))
139129
}
130+
140131
invisible(object)
141132
}
142133

R/cubist_rules.R

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -135,44 +135,36 @@ update.cubist_rules <-
135135
# make work in different places
136136

137137
#' @export
138-
check_args.cubist_rules <- function(object) {
138+
check_args.cubist_rules <- function(object, call = rlang::caller_env()) {
139139

140140
args <- lapply(object$args, rlang::eval_tidy)
141141

142-
if (is.numeric(args$committees)) {
143-
if (length(args$committees) > 1) {
144-
rlang::abort("Only a single committee member is used.")
145-
}
146-
msg <- "The number of committees should be >= 1 and <= 100. Truncating the value."
147-
if (args$committees > 100) {
148-
object$args$committees <-
149-
rlang::new_quosure(100L, env = rlang::empty_env())
150-
rlang::warn(msg)
151-
}
152-
if (args$committees < 1) {
153-
object$args$committees <-
154-
rlang::new_quosure(1L, env = rlang::empty_env())
155-
rlang::warn(msg)
156-
}
142+
check_number_whole(args$committees, allow_null = TRUE, call = call, arg = "committees")
157143

158-
}
159-
if (is.numeric(args$neighbors)) {
160-
if (length(args$neighbors) > 1) {
161-
rlang::abort("Only a single neighbors value is used.")
162-
}
163-
msg <- "The number of neighbors should be >= 0 and <= 9. Truncating the value."
164-
if (args$neighbors > 9) {
165-
object$args$neighbors <-
166-
rlang::new_quosure(9L, env = rlang::empty_env())
167-
rlang::warn(msg)
168-
}
169-
if (args$neighbors < 0) {
170-
object$args$neighbors <-
171-
rlang::new_quosure(0L, env = rlang::empty_env())
172-
rlang::warn(msg)
144+
msg <- "The number of committees should be {.code >= 1} and {.code <= 100}."
145+
if (!(is.null(args$committees)) && args$committees > 100) {
146+
object$args$committees <-
147+
rlang::new_quosure(100L, env = rlang::empty_env())
148+
cli::cli_warn(c(msg, "Truncating to 100."))
173149
}
150+
if (!(is.null(args$committees)) && args$committees < 1) {
151+
object$args$committees <-
152+
rlang::new_quosure(1L, env = rlang::empty_env())
153+
cli::cli_warn(c(msg, "Truncating to 1."))
154+
}
155+
156+
check_number_whole(args$neighbors, allow_null = TRUE, call = call, arg = "neighbors")
174157

158+
msg <- "The number of neighbors should be {.code >= 0} and {.code <= 9}."
159+
if (!(is.null(args$neighbors)) && args$neighbors > 9) {
160+
object$args$neighbors <- rlang::new_quosure(9L, env = rlang::empty_env())
161+
cli::cli_warn(c(msg, "Truncating to 9."))
175162
}
163+
if (!(is.null(args$neighbors)) && args$neighbors < 0) {
164+
object$args$neighbors <- rlang::new_quosure(0L, env = rlang::empty_env())
165+
cli::cli_warn(c(msg, "Truncating to 0."))
166+
}
167+
176168
invisible(object)
177169
}
178170

R/decision_tree.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,7 @@ translate.decision_tree <- function(x, engine = x$engine, ...) {
128128
# ------------------------------------------------------------------------------
129129

130130
#' @export
131-
check_args.decision_tree <- function(object) {
132-
if (object$engine == "C5.0" && object$mode == "regression")
133-
rlang::abort("C5.0 is classification only.")
131+
check_args.decision_tree <- function(object, call = rlang::caller_env()) {
134132
invisible(object)
135133
}
136134

R/discrim_flexible.R

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,14 @@ update.discrim_flexible <-
8585
# ------------------------------------------------------------------------------
8686

8787
#' @export
88-
check_args.discrim_flexible <- function(object) {
88+
check_args.discrim_flexible <- function(object, call = rlang::caller_env()) {
8989

9090
args <- lapply(object$args, rlang::eval_tidy)
9191

92-
if (is.numeric(args$prod_degree) && args$prod_degree < 0)
93-
stop("`prod_degree` should be >= 1", call. = FALSE)
94-
95-
if (is.numeric(args$num_terms) && args$num_terms < 0)
96-
stop("`num_terms` should be >= 1", call. = FALSE)
97-
98-
if (!is.character(args$prune_method) &&
99-
!is.null(args$prune_method) &&
100-
!is.character(args$prune_method))
101-
stop("`prune_method` should be a single string value", call. = FALSE)
102-
92+
check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree")
93+
check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms")
94+
check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method")
95+
10396
invisible(object)
10497
}
10598

R/discrim_linear.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,11 @@ update.discrim_linear <-
8080
# ------------------------------------------------------------------------------
8181

8282
#' @export
83-
check_args.discrim_linear <- function(object) {
83+
check_args.discrim_linear <- function(object, call = rlang::caller_env()) {
8484

8585
args <- lapply(object$args, rlang::eval_tidy)
8686

87-
if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) {
88-
stop("The amount of regularization should be >= 0", call. = FALSE)
89-
}
87+
check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty")
9088

9189
invisible(object)
9290
}

R/discrim_regularized.R

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,18 +95,13 @@ update.discrim_regularized <-
9595
# ------------------------------------------------------------------------------
9696

9797
#' @export
98-
check_args.discrim_regularized <- function(object) {
98+
check_args.discrim_regularized <- function(object, call = rlang::caller_env()) {
9999

100100
args <- lapply(object$args, rlang::eval_tidy)
101101

102-
if (is.numeric(args$frac_common_cov) &&
103-
(args$frac_common_cov < 0 | args$frac_common_cov > 1)) {
104-
stop("The common covariance fraction should be between zero and one", call. = FALSE)
105-
}
106-
if (is.numeric(args$frac_identity) &&
107-
(args$frac_identity < 0 | args$frac_identity > 1)) {
108-
stop("The identity matrix fraction should be between zero and one", call. = FALSE)
109-
}
102+
check_number_decimal(args$frac_common_cov, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_common_cov")
103+
check_number_decimal(args$frac_identity, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_identity")
104+
110105
invisible(object)
111106
}
112107

R/extract.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,15 @@
1414
#'
1515
#' - `extract_parameter_set_dials()` returns a set of dials parameter objects.
1616
#'
17+
#' - `extract_fit_time()` returns a tibble with fit times. The fit times
18+
#' correspond to the time for the parsnip engine to fit and do not include
19+
#' other portions of the elapsed time in [parsnip::fit.model_spec()].
20+
#'
1721
#' @param x A parsnip `model_fit` object or a parsnip `model_spec` object.
1822
#' @param parameter A single string for the parameter ID.
23+
#' @param summarize A logical for whether the elapsed fit time should be
24+
#' returned as a single row or multiple rows. Doesn't support `FALSE` for
25+
#' parsnip models.
1926
#' @param ... Not currently used.
2027
#' @details
2128
#' Extracting the underlying engine fit can be helpful for describing the
@@ -127,3 +134,20 @@ eval_call_info <- function(x) {
127134
extract_parameter_dials.model_spec <- function(x, parameter, ...) {
128135
extract_parameter_dials(extract_parameter_set_dials(x), parameter)
129136
}
137+
138+
#' @export
139+
#' @rdname extract-parsnip
140+
extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) {
141+
elapsed <- x[["elapsed"]][["elapsed"]][["elapsed"]]
142+
143+
if (is.na(elapsed) || is.null(elapsed)) {
144+
rlang::abort(
145+
"This model was fit before `extract_fit_time()` was added."
146+
)
147+
}
148+
149+
dplyr::tibble(
150+
stage_id = class(x$spec)[1],
151+
elapsed = elapsed
152+
)
153+
}

R/fit.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,8 +453,15 @@ allow_sparse <- function(x) {
453453
#' @export
454454
print.model_fit <- function(x, ...) {
455455
cat("parsnip model object\n\n")
456-
if (!is.na(x$elapsed[["elapsed"]])) {
457-
cat("Fit time: ", prettyunits::pretty_sec(x$elapsed[["elapsed"]]), "\n")
456+
457+
if (is.null(x$elapsed$print) && !is.na(x$elapsed[["elapsed"]])) {
458+
elapsed <- x$elapsed[["elapsed"]]
459+
cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n")
460+
}
461+
462+
if (isTRUE(x$elapsed$print)) {
463+
elapsed <- x$elapsed$elapsed[["elapsed"]]
464+
cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n")
458465
}
459466

460467
if (inherits(x$fit, "try-error")) {

0 commit comments

Comments
 (0)