Skip to content

Commit 7d0bd3c

Browse files
committed
fix tuning operation with torch models
1 parent 790b947 commit 7d0bd3c

9 files changed

+35
-31
lines changed

R/api_torch.R

+8-6
Original file line numberDiff line numberDiff line change
@@ -368,22 +368,24 @@
368368
.is_torch_model <- function(ml_model) {
369369
inherits(ml_model, "torch_model")
370370
}
371+
371372
.torch_has_cuda <- function(){
372373
torch::cuda_is_available()
373374
}
375+
374376
.torch_has_mps <- function(){
375377
torch::backends_mps_is_available()
376378
}
377379

378380
.torch_mem_info <- function() {
379-
if (.torch_has_cuda()){
380-
# Get memory summary
381+
mem_sum <- 0
382+
383+
if (.torch_has_cuda()) {
384+
# get current memory info in GB
381385
mem_sum <- torch::cuda_memory_stats()
382-
# Return current memory info in GB
383-
mem_sum[["allocated_bytes"]][["all"]][["current"]] / 10^9
384-
} else {
385-
mem_sum <- 0
386+
mem_sum <- mem_sum[["allocated_bytes"]][["all"]][["current"]] / 10^9
386387
}
388+
387389
return(mem_sum)
388390
}
389391
#' @title Verify if torch works on CUDA

R/sits_classify.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
#' (integer, min = 1, max = 16384).
3636
#' @param multicores Number of cores to be used for classification
3737
#' (integer, min = 1, max = 2048).
38-
#' @param gpu_memory Memory available in GPU in GB (default = 16)
38+
#' @param gpu_memory Memory available in GPU in GB (default = 4)
3939
#' @param n_sam_pol Number of time series per segment to be classified
4040
#' (integer, min = 10, max = 50).
4141
#' @param output_dir Valid directory for output file.

R/sits_lighttae.R

+5-3
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,12 @@ sits_lighttae <- function(samples = NULL,
339339
values <- array(
340340
data = as.matrix(values), dim = c(n_samples, n_times, n_bands)
341341
)
342-
# Load into GPU
343-
if (.torch_has_cuda()){
342+
# Get GPU memory
343+
gpu_memory <- sits_env[["gpu_memory"]]
344+
# if CUDA is available and gpu memory is defined, transform values
345+
# to torch dataloader
346+
if (.torch_has_cuda() && .has(gpu_memory)) {
344347
# set the batch size according to the GPU memory
345-
gpu_memory <- sits_env[["gpu_memory"]]
346348
b_size <- 2^gpu_memory
347349
# transfor the input array to a dataset
348350
values <- .as_dataset(values)

R/sits_mlp.R

+5-4
Original file line numberDiff line numberDiff line change
@@ -289,11 +289,12 @@ sits_mlp <- function(samples = NULL,
289289
values <- .pred_normalize(pred = values, stats = ml_stats)
290290
# Transform input into matrix
291291
values <- as.matrix(values)
292-
# if CUDA is available, transform to torch data set
293-
# Load into GPU
294-
if (.torch_has_cuda()){
292+
# Get GPU memory
293+
gpu_memory <- sits_env[["gpu_memory"]]
294+
# if CUDA is available and gpu memory is defined, transform values
295+
# to torch dataloader
296+
if (.torch_has_cuda() && .has(gpu_memory)) {
295297
# set the batch size according to the GPU memory
296-
gpu_memory <- sits_env[["gpu_memory"]]
297298
b_size <- 2^gpu_memory
298299
# transfor the input array to a dataset
299300
values <- .as_dataset(values)

R/sits_tae.R

+5-6
Original file line numberDiff line numberDiff line change
@@ -307,13 +307,12 @@ sits_tae <- function(samples = NULL,
307307
values <- array(
308308
data = as.matrix(values), dim = c(n_samples, n_times, n_bands)
309309
)
310-
# Predict using GPU if available
311-
# If not, use CPU
312-
# if CUDA is available, transform to torch data set
313-
# Load into GPU
314-
if (.torch_has_cuda()){
310+
# Get GPU memory
311+
gpu_memory <- sits_env[["gpu_memory"]]
312+
# if CUDA is available and gpu memory is defined, transform values
313+
# to torch dataloader
314+
if (.torch_has_cuda() && .has(gpu_memory)) {
315315
# set the batch size according to the GPU memory
316-
gpu_memory <- sits_env[["gpu_memory"]]
317316
b_size <- 2^gpu_memory
318317
# transfor the input array to a dataset
319318
values <- .as_dataset(values)

R/sits_tempcnn.R

+6-5
Original file line numberDiff line numberDiff line change
@@ -358,15 +358,16 @@ sits_tempcnn <- function(samples = NULL,
358358
values <- array(
359359
data = as.matrix(values), dim = c(n_samples, n_times, n_bands)
360360
)
361-
# if CUDA is available, transform to torch data set
362-
# Load into GPU
363-
if (.torch_has_cuda()){
361+
# Get GPU memory
362+
gpu_memory <- sits_env[["gpu_memory"]]
363+
# if CUDA is available and gpu memory is defined, transform values
364+
# to torch dataloader
365+
if (.torch_has_cuda() && .has(gpu_memory)) {
364366
# set the batch size according to the GPU memory
365-
gpu_memory <- sits_env[["gpu_memory"]]
366367
b_size <- 2^gpu_memory
367368
# transfor the input array to a dataset
368369
values <- .as_dataset(values)
369-
# To the data set to a torcj transform in a dataloader to use the batch size
370+
# To the data set to a torch transform in a dataloader to use the batch size
370371
values <- torch::dataloader(values, batch_size = b_size)
371372
# Do GPU classification with dataloader
372373
values <- .try(

R/sits_tuning.R

+3-4
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
#' \code{ml_method}. User can use \code{uniform}, \code{choice},
3030
#' \code{randint}, \code{normal}, \code{lognormal}, \code{loguniform},
3131
#' and \code{beta} distribution functions to randomize parameters.
32-
#' @param trials Number of random trials to perform the random search.
33-
#' @param progress Show progress bar?
34-
#' @param multicores Number of cores to process in parallel
32+
#' @param trials Number of random trials to perform the random search.
33+
#' @param progress Show progress bar?
34+
#' @param multicores Number of cores to process in parallel.
3535
#'
3636
#' @return
3737
#' A tibble containing all parameters used to train on each trial
@@ -87,7 +87,6 @@ sits_tuning <- function(samples,
8787
# check validation_split parameter if samples_validation is not passed
8888
.check_num_parameter(validation_split, exclusive_min = 0, max = 0.5)
8989
}
90-
9190
# check 'ml_functions' parameter
9291
ml_function <- substitute(ml_method, env = environment())
9392
if (is.call(ml_function))

man/sits_classify.Rd

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/sits_tuning.Rd

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)