Skip to content

Commit 513aa15

Browse files
improve tuning to work with ML models
1 parent 02014a4 commit 513aa15

File tree

4 files changed

+67
-13
lines changed

4 files changed

+67
-13
lines changed

R/sits_machine_learning.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#' @param num_trees Number of trees to grow. This should not be set to too
1515
#' small a number, to ensure that every input
1616
#' row gets predicted at least a few times (default: 100)
17-
#' (integer, min = 50, max = 150).
17+
#' (integer, min = 20).
1818
#' @param mtry Number of variables randomly sampled as candidates at
1919
#' each split (default: NULL - use default value of
2020
#' \code{randomForest::randomForest()} function, i.e.

R/sits_tuning.R

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99
#' hyperparameters for deep learning models.
1010
#'
1111
#' @note
12-
#' Machine learning models use stochastic gradient descent (SGD) techniques to
13-
#' find optimal solutions. To perform SGD, models use optimization
14-
#' algorithms which have hyperparameters that have to be adjusted
15-
#' to achieve best performance for each application.
16-
#
12+
#'
13+
#' Machine learning algorithms have hyperparameters that control
14+
#' the algorithm's behaviour. This function allows users to test
15+
#' different combinations of hyperparameters for a given sample set,
16+
#' thus selecting a set of values which fits the training data.
17+
#' The \code{sits_tuning} function can be used with both traditional
18+
#' machine learning methods (e.g., random forests) as weel as
19+
#' deep learning ones.
20+
#'
1721
#' Instead of performing an exhaustive test of all parameter combinations,
1822
#' \code{sits_tuning} selects them randomly.
1923
#' Validation is done using an independent set
@@ -22,6 +26,11 @@
2226
#' parameter should be passed by calling
2327
#' \code{\link[sits]{sits_tuning_hparams}}.
2428
#'
29+
#' Deep learning models use stochastic gradient descent (SGD) techniques to
30+
#' find optimal solutions. To perform SGD, models use optimization
31+
#' algorithms which have hyperparameters that have to be adjusted
32+
#' to achieve best performance for each application.
33+
#'
2534
#' When using a GPU for deep learning, \code{gpu_memory} indicates the
2635
#' memory of the graphics card which is available for processing.
2736
#' The parameter \code{batch_size} defines the size of the matrix
@@ -69,7 +78,7 @@
6978
#'
7079
#' @examples
7180
#' if (sits_run_examples()) {
72-
#' # find best learning rate parameters for TempCNN
81+
#' # find best learning rate for TempCNN
7382
#' tuned <- sits_tuning(
7483
#' samples_modis_ndvi,
7584
#' ml_method = sits_tempcnn(),
@@ -89,6 +98,22 @@
8998
#' accuracy <- tuned$accuracy[[1]]
9099
#' kappa <- tuned$kappa[[1]]
91100
#' best_lr <- tuned$opt_hparams[[1]]$lr
101+
#'.
102+
#' # find best number of trees for random foresr
103+
#' rf_tuned <- sits_tuning(
104+
#' samples_modis_ndvi,
105+
#' ml_method = sits_rfor(),
106+
#' params = sits_tuning_hparams(
107+
#' num_trees = choice(100, 200, 300)
108+
#' ),
109+
#' trials = 10,
110+
#' multicores = 2,
111+
#' progress = FALSE
112+
#' )
113+
#' # obtain best accuracy, kappa and best_lr
114+
#' rf_accuracy <- rf_tuned$accuracy[[1]]
115+
#' rf_kappa <- rf_tuned$kappa[[1]]
116+
#' rf_best_num_trees <- rf_tuned$num_trees
92117
#' }
93118
#'
94119
#' @export
@@ -130,7 +155,11 @@ sits_tuning <- function(samples,
130155
.check_that(!"samples" %in% names(params),
131156
msg = .conf("messages", "sits_tuning_samples")
132157
)
158+
# get the parameters with defaults
133159
params_default <- formals(ml_function)
160+
# remove dots from parameters
161+
params_default <- params_default[names(params_default) != "..."]
162+
# check parameters
134163
.check_chr_within(
135164
x = names(params),
136165
within = names(params_default)

man/sits_rfor.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/sits_tuning.Rd

Lines changed: 30 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)