Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion R-package/R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,11 @@ xgb.cb.reset.parameters <- function(new_params) {
#' `metric_name = 'dtest-auc'` or `metric_name = 'dtest_auc'`.
#' All dash '-' characters in metric names are considered equivalent to '_'.
#' @param verbose Whether to print the early stopping information.
#'
#' @param keep_all_iter Whether to keep all of the boosting rounds that were produced
#' in the resulting object. If passing `FALSE`, will only keep the boosting rounds
#' up to the detected best iteration, discarding the ones that come after.
#' up to the detected best iteration, discarding the ones that come after. This
#' parameter is not supported by the `xgb.cv` function and the `gblinear` booster yet.
#' @return An `xgb.Callback` object, which can be passed to [xgb.train()] or [xgb.cv()].
#' @export
xgb.cb.early.stop <- function(
Expand Down Expand Up @@ -647,6 +649,9 @@ xgb.cb.early.stop <- function(
if (inherits(model, "xgb.Booster") && !length(evals)) {
stop("For early stopping, 'evals' must have at least one element")
}
if (!inherits(model, "xgb.Booster") && keep_all_iter) {
stop("`keep_all_iter` must be set to FALSE for cv.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
stop("`keep_all_iter` must be set to FALSE for cv.")
stop("'keep_all_iter' must be set to FALSE when using early stopping in 'xgb.cv'.")

}
env$begin_iteration <- begin_iteration
return(NULL)
},
Expand Down
7 changes: 6 additions & 1 deletion R-package/R/xgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,13 @@ xgb.cv <- function(params = xgb.params(), data, nrounds, nfold,
check.deprecation(deprecated_cv_params, match.call(), ...)

stopifnot(inherits(data, "xgb.DMatrix"))

if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
}
if (inherits(data, "xgb.QuantileDMatrix")) {
stop("`xgb.QuantileDMatrix` is not yet supported for the cv function.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
stop("`xgb.QuantileDMatrix` is not yet supported for the cv function.")
stop("'xgb.QuantileDMatrix' is not supported as input to 'xgb.cv'.")

}

params <- check.booster.params(params)
# TODO: should we deprecate the redundant 'metrics' parameter?
Expand Down Expand Up @@ -171,7 +175,8 @@ xgb.cv <- function(params = xgb.params(), data, nrounds, nfold,
xgb.cb.early.stop(
early_stopping_rounds,
maximize = maximize,
verbose = verbose
verbose = verbose,
keep_all_iter = FALSE
),
as_first_elt = TRUE
)
Expand Down
3 changes: 2 additions & 1 deletion R-package/man/xgb.cb.early.stop.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,40 @@ test_that("xgb.cv works", {
expect_false(is.null(cv$call))
})

test_that("xgb.cv invalid inputs", {
data("mtcars")
y <- mtcars$mpg
x_df <- mtcars[, -1]

expect_error(
cv <- xgb.cv(
data = xgb.QuantileDMatrix(x_df, label = y),
nfold = 5,
nrounds = 2,
params = xgb.params(
max_depth = 2,
nthread = n_threads
)
),
regexp = ".*QuantileDMatrix.*"
)
expect_error(
cv <- xgb.cv(
data = xgb.DMatrix(x_df, label = y),
nfold = 5,
nrounds = 2,
params = xgb.params(
max_depth = 2,
nthread = n_threads,
),
callbacks = list(
xgb.cb.early.stop(stopping_rounds = 3)
)
),
regexp = ".*keep_all_iter.*"
)
})

test_that("xgb.cv works with stratified folds", {
dtrain <- xgb.DMatrix(train$data, label = train$label, nthread = n_threads)
set.seed(314159)
Expand Down
Loading