Skip to content

Commit ad28e95

Browse files
surv.xgboost model type split + add distr predictions to surv.xgboost.cox (#333)
* add type init parameter * update tests, add one more * update docs * revert back to xgboost without type argument + small refactoring * add doc template for early stopping * refactor: convert function from task to xgboost data matrix * add xgboost Cox and AFT separate learners * fix roxygen warning * remove deprecated parameter * export new xgboost learners * small doc fix * doc: change early stopping position * update aorsf doc * revert back to old doc for surv.xgboost (objective-non-specific) * add doc for prediction types doc and refactor output prediction for xgboost AFT * more parmaeter tests * revert tests back to use original xgboost implementation * add new tests * add docs for the two types of xgboost learners * correct parameter name in aorsf * fix style warnings * more styling issues fixed * fix test (using rvest 1.0.4) * add comments * refactor xgboost importance function * add distr predictions to surv.xgboost.cox via Breslow * update xgboost tests * small fix * fix importance return value * doc update * add distr breslow test for surv.xgboost.cox * add note to old xgboost survival learner * clean up return type + add online doc for it as a comment * update NEWS.md * update docs * doc improvements * supress warnings for to-be-deprecated surv.xgboost learner * update: run document() across all learners * hardcode 'objective' and 'eval_metric' learner parameters and update tests --------- Co-authored-by: Sebastian Fischer <[email protected]>
1 parent 5e291e0 commit ad28e95

File tree

102 files changed

+1311
-318
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

102 files changed

+1311
-318
lines changed

NAMESPACE

+2
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ export(LearnerSurvRandomForestSRC)
125125
export(LearnerSurvRanger)
126126
export(LearnerSurvSVM)
127127
export(LearnerSurvXgboost)
128+
export(LearnerSurvXgboostAFT)
129+
export(LearnerSurvXgboostCox)
128130
export(create_learner)
129131
export(install_learners)
130132
export(list_mlr3learners)

NEWS.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# mlr3extralearners 0.7.1-9000
22

3+
* Added `surv.xgboost.cox` and `surv.xgboost.aft` separate survival learners. The `distr` prediction of the cox xgboost learner is now estimated via Breslow by default, and the aft xgboost learner now additionally provides a `response` prediction (survival time)
34
* Ported `surv.parametric` code to `survivalmodels`, changed `type` parameter to `form` to avoid conflict with survivalmodels's default parameter list
45
* Fix: Replace hardcoded `VectorDistribution`s from partykit and flexsurv survival learners with survival matrices (`Matdist`) (thanks to @bblodfon)
56
* Feat: Add `discrete` parameter in `surv.parametric` learner to return `Matdist` survival predictions

R/bibentries.R

+10-1
Original file line numberDiff line numberDiff line change
@@ -586,8 +586,17 @@ bibentries = c( # nolint start
586586
month = "01",
587587
journal = "University of California, Berkeley"
588588
),
589+
barnwal2022 = bibentry("article",
590+
title = "Survival Regression with Accelerated Failure Time Model in XGBoost",
591+
author = "Barnwal Avinash, Cho Hyunsu and Hocking Toby",
592+
doi = "10.1080/10618600.2022.2067548",
593+
issn = "15372715",
594+
journal = "Journal of Computational and Graphical Statistics",
595+
publisher = "American Statistical Association",
596+
year = "2022"
597+
),
589598
Kohavi1995 = bibentry("inproceedings",
590-
author = "Ron Kohavi",
599+
author = "Ron Kohavi",
591600
booktitle = "8th European Conference on Machine Learning",
592601
pages = "174--189",
593602
publisher = "Springer",

R/helpers_xgboost.R

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
# Build an `xgb.DMatrix` from a survival `task`, encoding the target
# (time, status) in the form the requested xgboost `objective` expects.
#
# @param task A survival task with a two-column truth (time, status).
# @param objective Either "survival:cox" or "survival:aft".
# @param row_ids Optional subset of task rows; defaults to all rows.
# @return An `xgb.DMatrix` holding the feature data plus label info.
get_xgb_mat = function(task, objective, row_ids = NULL) {
  # fall back to every row of the task when no subset is requested
  if (is.null(row_ids)) row_ids = task$row_ids

  feats = task$data(rows = row_ids, cols = task$feature_names)
  surv = task$truth(rows = row_ids)
  event_times = surv[, 1]
  event_status = surv[, 2]

  if (objective == "survival:cox") { # Cox
    # xgboost's Cox objective encodes censoring via the label's sign:
    # censored rows get negative times, events keep positive times
    censored = event_status != 1
    event_times[censored] = -1L * event_times[censored]
    xgboost::xgb.DMatrix(
      data = as_numeric_matrix(feats),
      label = event_times
    )
  } else { # AFT
    # interval-censored encoding: for censored rows the event is only
    # known to happen after the observed time, so the upper bound is Inf
    lower = upper = event_times
    upper[event_status == 0] = Inf

    dmat = xgboost::xgb.DMatrix(as_numeric_matrix(feats))
    xgboost::setinfo(dmat, "label_lower_bound", lower)
    xgboost::setinfo(dmat, "label_upper_bound", upper)
    dmat
  }
}
# Extract per-feature gain-based importance scores from a fitted
# `xgb.Booster`.
#
# @param model A trained `xgb.Booster`, or NULL when nothing was stored.
# @return Named `numeric()` of Gain scores, keyed by feature name.
xgb_imp = function(model) {
  # guard: the learner has not been trained yet
  if (is.null(model)) {
    stopf("No model stored")
  }

  importance = xgboost::xgb.importance(model = model)
  set_names(importance$Gain, importance$Feature)
}

R/learner_aorsf_surv_aorsf.R

+9-11
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,8 @@ LearnerSurvAorsf = R6Class("LearnerSurvAorsf",
4545
control_type = p_fct(levels = c("fast", "cph", "net"), default = "fast", tags = "train"),
4646
split_rule = p_fct(levels = c("logrank", "cstat"), default = "logrank", tags = "train"),
4747
control_fast_do_scale = p_lgl(default = FALSE, tags = "train"),
48-
control_fast_ties = p_fct(levels = c("efron", "breslow"),
49-
default = "efron", tags = "train"),
50-
control_cph_ties = p_fct(levels = c("efron", "breslow"),
51-
default = "efron", tags = "train"),
48+
control_fast_ties = p_fct(levels = c("efron", "breslow"), default = "efron", tags = "train"),
49+
control_cph_ties = p_fct(levels = c("efron", "breslow"), default = "efron", tags = "train"),
5250
control_cph_eps = p_dbl(default = 1e-9, lower = 0, tags = "train"),
5351
control_cph_iter_max = p_int(default = 20L, lower = 1, tags = "train"),
5452
control_net_alpha = p_dbl(default = 0.5, tags = "train"),
@@ -146,13 +144,13 @@ LearnerSurvAorsf = R6Class("LearnerSurvAorsf",
146144
# these parameters are used to organize the control arguments
147145
# above but are not used directly by aorsf::orsf(), so:
148146
pv = remove_named(pv, c("control_type",
149-
"control_fast_do_scale",
150-
"control_fast_ties",
151-
"control_cph_ties",
152-
"control_cph_eps",
153-
"control_cph_iter_max",
154-
"control_net_alpha",
155-
"control_net_df_target"))
147+
"control_fast_do_scale",
148+
"control_fast_ties",
149+
"control_cph_ties",
150+
"control_cph_eps",
151+
"control_cph_iter_max",
152+
"control_net_alpha",
153+
"control_net_df_target"))
156154
invoke(
157155
aorsf::orsf,
158156
data = task$data(),

R/learner_xgboost_surv_xgboost.R

+13-47
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,23 @@
66
#' eXtreme Gradient Boosting regression.
77
#' Calls [xgboost::xgb.train()] from package \CRANpkg{xgboost}.
88
#'
9+
#' **Note:** We strongly advise to use the separate [Cox][LearnerSurvXgboostCox]
10+
#' and [AFT][LearnerSurvXgboostAFT] xgboost survival learners since they represent
11+
#' two very distinct survival modeling methods and we offer more prediction
12+
#' types in the respective learners compared to the ones available here.
13+
#' This learner will be deprecated in the future.
14+
#'
915
#' @template note_xgboost
1016
#'
1117
#' @section Initial parameter values:
1218
#' - `nrounds` is initialized to 1.
1319
#' - `nthread` is initialized to 1 to avoid conflicts with parallelization via \CRANpkg{future}.
1420
#' - `verbose` is initialized to 0.
1521
#' - `objective` is initialized to `survival:cox` for survival analysis.
16-
#' @section Early stopping:
17-
#' Early stopping can be used to find the optimal number of boosting rounds.
18-
#' The `early_stopping_set` parameter controls which set is used to monitor the performance.
19-
#' Set `early_stopping_set = "test"` to monitor the performance of the model on the test set while training.
20-
#' The test set for early stopping can be set with the `"test"` row role in the [mlr3::Task].
21-
#' Additionally, the range must be set in which the performance must increase with `early_stopping_rounds` and the maximum number of boosting rounds with `nrounds`.
22-
#' While resampling, the test set is automatically applied from the [mlr3::Resampling].
23-
#' Not that using the test set for early stopping can potentially bias the performance scores.
2422
#'
2523
#' @templateVar id surv.xgboost
2624
#' @template learner
25+
#' @template section_early_stopping
2726
#'
2827
#' @references
2928
#' `r format_bib("chen_2016")`
@@ -37,6 +36,9 @@ LearnerSurvXgboost = R6Class("LearnerSurvXgboost",
3736
#' @description
3837
#' Creates a new instance of this [R6][R6::R6Class] class.
3938
initialize = function() {
39+
.Deprecated(
40+
msg = "'surv.xgboost' will be deprecated in the future. Use 'surv.xgboost.cox' or 'surv.xgboost.aft' learners instead." #nolint
41+
)
4042

4143
ps = ps(
4244
aft_loss_distribution = p_fct(c("normal", "logistic", "extreme"), default = "normal", tags = "train"),
@@ -71,7 +73,6 @@ LearnerSurvXgboost = R6Class("LearnerSurvXgboost",
7173
normalize_type = p_fct(c("tree", "forest"), default = "tree", tags = "train"),
7274
nrounds = p_int(1L, tags = "train"),
7375
nthread = p_int(1L, default = 1L, tags = c("train", "threads")),
74-
ntreelimit = p_int(1L, tags = "predict"),
7576
num_parallel_tree = p_int(1L, default = 1L, tags = "train"),
7677
objective = p_fct(c("survival:cox", "survival:aft"), default = "survival:cox", tags = c("train", "predict")),
7778
one_drop = p_lgl(default = FALSE, tags = "train"),
@@ -134,46 +135,11 @@ LearnerSurvXgboost = R6Class("LearnerSurvXgboost",
134135
#'
135136
#' @return Named `numeric()`.
136137
importance = function() {
137-
if (is.null(self$model)) {
138-
stopf("No model stored")
139-
}
140-
141-
imp = xgboost::xgb.importance(
142-
model = self$model
143-
)
144-
set_names(imp$Gain, imp$Feature)
138+
xgb_imp(self$model)
145139
}
146140
),
147141

148142
private = list(
149-
# helper function to construct an `xgb.DMatrix` object
150-
.get_data = function(task, pv, row_ids = NULL) {
151-
# use all task rows if `rows_ids` is not specified
152-
if (is.null(row_ids))
153-
row_ids = task$row_ids
154-
155-
data = task$data(rows = row_ids, cols = task$feature_names)
156-
target = task$data(rows = row_ids, cols = task$target_names)
157-
targets = task$target_names
158-
label = target[[targets[1]]] # time
159-
status = target[[targets[2]]]
160-
161-
if (pv$objective == "survival:cox") {
162-
label[status != 1] = -1L * label[status != 1]
163-
data = xgboost::xgb.DMatrix(
164-
data = as_numeric_matrix(data),
165-
label = label)
166-
} else {
167-
y_lower_bound = y_upper_bound = label
168-
y_upper_bound[status == 0] = Inf
169-
170-
data = xgboost::xgb.DMatrix(as_numeric_matrix(data))
171-
xgboost::setinfo(data, "label_lower_bound", y_lower_bound)
172-
xgboost::setinfo(data, "label_upper_bound", y_upper_bound)
173-
}
174-
data
175-
},
176-
177143
.train = function(task) {
178144

179145
pv = self$param_set$get_values(tags = "train")
@@ -188,7 +154,7 @@ LearnerSurvXgboost = R6Class("LearnerSurvXgboost",
188154
pv$eval_metric = "aft-nloglik"
189155
}
190156

191-
data = private$.get_data(task, pv)
157+
data = get_xgb_mat(task, pv$objective)
192158

193159
if ("weights" %in% task$properties) {
194160
xgboost::setinfo(data, "weight", task$weights$weight)
@@ -201,7 +167,7 @@ LearnerSurvXgboost = R6Class("LearnerSurvXgboost",
201167
}
202168

203169
if (pv$early_stopping_set == "test" && !is.null(task$row_roles$test)) {
204-
test_data = private$.get_data(task, pv, task$row_roles$test)
170+
test_data = get_xgb_mat(task, pv$objective, task$row_roles$test)
205171
pv$watchlist = c(pv$watchlist, list(test = test_data))
206172
}
207173
pv$early_stopping_set = NULL

0 commit comments

Comments
 (0)