Commit b10edc6

improve api_torch and disable MPS
1 parent: e6e1526

11 files changed (+112 -105 lines)

R/api_classify.R (+5 -11)

@@ -195,10 +195,8 @@
             start_time = tile_start_time,
             verbose = verbose
         )
-        # Clean torch allocations
-        if (.is_torch_model(ml_model) && .torch_has_cuda()) {
-            torch::cuda_empty_cache()
-        }
+        # Clean GPU memory allocation
+        .ml_gpu_clean(ml_model)
         # Return probs tile
         probs_tile
     }
@@ -527,9 +525,7 @@
         )
     }
     # choose between GPU and CPU
-    if (inherits(ml_model, "torch_model") &&
-        (.torch_has_cuda() || .torch_has_mps())
-    )
+    if (.torch_gpu_enabled(ml_model))
         prediction <- .classify_ts_gpu(
             pred = pred,
             ml_model = ml_model,
@@ -630,10 +626,8 @@
         values <- ml_model(values)
         # Return classification
         values <- tibble::tibble(data.frame(values))
-        # Clean torch cache
-        if (.is_torch_model(ml_model) && .torch_has_cuda()) {
-            torch::cuda_empty_cache()
-        }
+        # Clean GPU memory
+        .ml_gpu_clean(ml_model)
         return(values)
     })
     return(prediction)

R/api_ml_model.R (+12 -0)

@@ -93,3 +93,15 @@
     names(labels) <- seq_along(labels)
     labels
 }
+#' @title Clean GPU memory allocation
+#' @keywords internal
+#' @noRd
+#' @param ml_model  Closure that contains ML model and its environment
+#' @return Called for side effects
+.ml_gpu_clean <- function(ml_model) {
+    # Clean torch allocations
+    if (.is_torch_model(ml_model) && .torch_has_cuda()) {
+        torch::cuda_empty_cache()
+    }
+    return(invisible(NULL))
+}
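
The new helper is deliberately safe to call unconditionally: it is a no-op unless the model is a torch model on a machine with CUDA. A minimal sketch of the call pattern this commit adopts in R/api_classify.R (surrounding code elided):

    # run the model, then release GPU memory (no-op on CPU-only setups)
    values <- ml_model(values)
    .ml_gpu_clean(ml_model)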

R/api_torch.R (+71 -1)

@@ -372,7 +372,8 @@
     torch::cuda_is_available()
 }
 .torch_has_mps <- function(){
-    torch::backends_mps_is_available()
+    # torch::backends_mps_is_available()
+    return(FALSE)
 }
 
 .torch_mem_info <- function() {
@@ -386,6 +387,75 @@
     }
     return(mem_sum)
 }
+#' @title Verify if torch works on GPU
+#' @name .torch_gpu_enabled
+#' @author Gilberto Camara, \email{gilberto.camara@@inpe.br}
+#' @keywords internal
+#' @noRd
+#' @description Use CPU or GPU for torch models depending on
+#' availability
+#'
+#' @param ml_model  ML model
+#'
+#' @return TRUE/FALSE
+#'
+.torch_gpu_enabled <- function(ml_model){
+    gpu_enabled <- (inherits(ml_model, "torch_model") &&
+        (.torch_has_cuda() || .torch_has_mps())
+    )
+    return(gpu_enabled)
+}
+#' @title Torch function for prediction in GPU or CPU
+#' @name .torch_predict
+#' @author Gilberto Camara, \email{gilberto.camara@@inpe.br}
+#' @keywords internal
+#' @noRd
+#' @description Use CPU or GPU for torch models depending on
+#' availability
+#'
+#' @param values       Values to be predicted
+#' @param torch_model  Torch model
+#'
+#' @return Predicted values
+#'
+.torch_predict <- function(values, torch_model){
+    # if CUDA is available, transform to torch data set
+    # Load into GPU
+    if (.torch_has_cuda() || .torch_has_mps()) {
+        values <- .as_dataset(values)
+        # We need to transform in a dataloader to use the batch size
+        values <- torch::dataloader(
+            values, batch_size = 2^15
+        )
+        # Do GPU classification
+        values <- .try(
+            stats::predict(object = torch_model, values),
+            .msg_error = .conf("messages", ".check_gpu_memory_size")
+        )
+    } else {
+        # Do CPU classification
+        values <- stats::predict(object = torch_model, values)
+    }
+    return(values)
+}
+#' @title Use GPU or CPU train for MPS Apple
+#' @name .torch_mps_train
+#' @author Gilberto Camara, \email{gilberto.camara@@inpe.br}
+#' @keywords internal
+#' @noRd
+#' @description Use CPU or GPU for torch models depending on
+#' availability
+#'
+#' @return TRUE/FALSE
+#'
+.torch_mps_train <- function() {
+    if (torch::backends_mps_is_available())
+        cpu_train <- TRUE
+    else
+        cpu_train <- FALSE
+
+    return(cpu_train)
+}
 
 .as_dataset <- torch::dataset(
     "dataset",

R/sits_classify.R (+3 -5)

@@ -228,8 +228,7 @@ sits_classify.raster_cube <- function(data,
     # Get default proc bloat
     proc_bloat <- .conf("processing_bloat_cpu")
     # If we using the GPU, gpu_memory parameter needs to be specified
-    if (.is_torch_model(ml_model) &&
-        (.torch_has_cuda() || .torch_has_mps())) {
+    if (.torch_gpu_enabled(ml_model)) {
         .check_int_parameter(gpu_memory, min = 1, max = 16384,
             msg = .conf("messages", ".check_gpu_memory")
         )
@@ -356,7 +355,7 @@ sits_classify.segs_cube <- function(data,
     end_date = NULL,
     memsize = 8L,
     multicores = 2L,
-    gpu_memory = 16,
+    gpu_memory = 4,
     output_dir,
     version = "v1",
     n_sam_pol = NULL,
@@ -377,8 +376,7 @@ sits_classify.segs_cube <- function(data,
     .check_progress(progress)
     proc_bloat <- .conf("processing_bloat_seg_class")
     # If we using the GPU, gpu_memory parameter needs to be specified
-    if (.is_torch_model(ml_model) &&
-        (.torch_has_cuda() || .torch_has_mps())) {
+    if (.torch_gpu_enabled(ml_model)) {
         .check_int_parameter(gpu_memory, min = 1, max = 16384,
             msg = .conf("messages", ".check_gpu_memory")
         )
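
From the user's side, the visible change is the lower gpu_memory default for segmentation cubes (4 GB rather than 16 GB). A hedged sketch of a call that exercises the check (the cube and model objects are hypothetical placeholders):

    # gpu_memory is validated only when .torch_gpu_enabled(ml_model) is TRUE,
    # and must be an integer in [1, 16384]
    result <- sits_classify(
        data       = segs_cube,     # hypothetical segments cube
        ml_model   = torch_model,   # e.g. trained with sits_tempcnn()
        gpu_memory = 4,             # new default for segs_cube
        multicores = 2,
        output_dir = tempdir()
    )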

R/sits_lighttae.R (+4 -21)

@@ -271,10 +271,7 @@ sits_lighttae <- function(samples = NULL,
         }
     )
     # torch 12.0 not working with Apple MPS
-    if (torch::backends_mps_is_available())
-        cpu_train <- TRUE
-    else
-        cpu_train <- FALSE
+    cpu_train <- .torch_mps_train()
     # Train the model using luz
     torch_model <-
         luz::setup(
@@ -342,23 +339,9 @@
         values <- array(
             data = as.matrix(values), dim = c(n_samples, n_times, n_bands)
         )
-        # if CUDA is available, transform to torch data set
-        # Load into GPU
-        if (.torch_has_cuda() || .torch_has_mps()) {
-            values <- .as_dataset(values)
-            # We need to transform in a dataloader to use the batch size
-            values <- torch::dataloader(
-                values, batch_size = 2^15
-            )
-            # Do GPU classification
-            values <- .try(
-                stats::predict(object = torch_model, values),
-                .msg_error = .conf("messages", ".check_gpu_memory_size")
-            )
-        } else {
-            # Do CPU classification
-            values <- stats::predict(object = torch_model, values)
-        }
+        # Predict using GPU if available
+        # If not, use CPU
+        values <- .torch_predict(values, torch_model)
         # Convert to tensor CPU
         values <- torch::as_array(
             x = torch::torch_tensor(values, device = "cpu")
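
The cpu_train flag replaces the inline MPS check in all four deep-learning constructors. How the flag feeds into luz is not shown in this diff; a sketch under that assumption (luz_module and train_dl are hypothetical placeholders):

    cpu_train <- .torch_mps_train()   # TRUE when Apple MPS is detected
    # assumption: the flag forces CPU training through luz's accelerator,
    # sidestepping the torch 12.0 MPS issue noted in the code comment
    fitted <- luz::fit(
        luz_module,                   # hypothetical result of luz::setup()
        data = train_dl,              # hypothetical training dataloader
        accelerator = luz::accelerator(cpu = cpu_train)
    )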

R/sits_mlp.R (+5 -22)

@@ -58,7 +58,8 @@
 #' @examples
 #' if (sits_run_examples()) {
 #'     # create an MLP model
-#'     torch_model <- sits_train(samples_modis_ndvi, sits_mlp())
+#'     torch_model <- sits_train(samples_modis_ndvi,
+#'         sits_mlp(epochs = 20, verbose = TRUE))
 #'     # plot the model
 #'     plot(torch_model)
 #'     # create a data cube from local files
@@ -241,10 +242,7 @@ sits_mlp <- function(samples = NULL,
         }
     )
     # torch 12.0 not working with Apple MPS
-    if (torch::backends_mps_is_available())
-        cpu_train <- TRUE
-    else
-        cpu_train <- FALSE
+    cpu_train <- .torch_mps_train()
     # Train the model using luz
     torch_model <-
         luz::setup(
@@ -292,23 +290,8 @@
         values <- .pred_normalize(pred = values, stats = ml_stats)
         # Transform input into matrix
         values <- as.matrix(values)
-        # if CUDA is available, transform to torch data set
-        # Load into GPU
-        if (.torch_has_cuda() || .torch_has_mps()) {
-            values <- .as_dataset(values)
-            # We need to transform in a dataloader to use the batch size
-            values <- torch::dataloader(
-                values, batch_size = 2^15
-            )
-            # Do GPU classification
-            values <- .try(
-                stats::predict(object = torch_model, values),
-                .msg_error = .conf("messages", ".check_gpu_memory_size")
-            )
-        } else {
-            # Do CPU classification
-            values <- stats::predict(object = torch_model, values)
-        }
+        # predict using CPU or GPU depending on machine
+        values <- .torch_predict(values, torch_model)
         # Convert to tensor cpu to support GPU processing
         values <- torch::as_array(
             x = torch::torch_tensor(values, device = "cpu")

R/sits_tae.R (+4 -21)

@@ -244,10 +244,7 @@ sits_tae <- function(samples = NULL,
         }
     )
     # torch 12.0 not working with Apple MPS
-    if (.torch_has_mps())
-        cpu_train <- TRUE
-    else
-        cpu_train <- FALSE
+    cpu_train <- .torch_mps_train()
     # train the model using luz
     torch_model <-
         luz::setup(
@@ -309,23 +306,9 @@
         values <- array(
             data = as.matrix(values), dim = c(n_samples, n_times, n_bands)
         )
-        # if CUDA is available, transform to torch data set
-        # Load into GPU
-        if (.torch_has_cuda() || .torch_has_mps()) {
-            values <- .as_dataset(values)
-            # We need to transform in a dataloader to use the batch size
-            values <- torch::dataloader(
-                values, batch_size = 2^15
-            )
-            # Do GPU classification
-            values <- .try(
-                stats::predict(object = torch_model, values),
-                .msg_error = .conf("messages", ".check_gpu_memory_size")
-            )
-        } else {
-            # Do CPU classification
-            values <- stats::predict(object = torch_model, values)
-        }
+        # Predict using GPU if available
+        # If not, use CPU
+        values <- .torch_predict(values, torch_model)
         # Convert to tensor CPU
         values <- torch::as_array(
             x = torch::torch_tensor(values, device = "cpu")

R/sits_tempcnn.R (+4 -21)

@@ -60,7 +60,7 @@
 #' if (sits_run_examples()) {
 #'     # create a TempCNN model
 #'     torch_model <- sits_train(samples_modis_ndvi,
-#'         sits_tempcnn(verbose = TRUE))
+#'         sits_tempcnn(epochs = 20, verbose = TRUE))
 #'     # plot the model
 #'     plot(torch_model)
 #'     # create a data cube from local files
@@ -290,10 +290,7 @@ sits_tempcnn <- function(samples = NULL,
         }
     )
     # torch 12.0 not working with Apple MPS
-    if (torch::backends_mps_is_available())
-        cpu_train <- TRUE
-    else
-        cpu_train <- FALSE
+    cpu_train <- .torch_mps_train()
     # Train the model using luz
     torch_model <-
         luz::setup(
@@ -362,22 +359,8 @@
             data = as.matrix(values), dim = c(n_samples, n_times, n_bands)
         )
         # if CUDA is available, transform to torch data set
-        # Load into GPU
-        if (.torch_has_cuda() || .torch_has_mps()) {
-            values <- .as_dataset(values)
-            # We need to transform in a dataloader to use the batch size
-            values <- torch::dataloader(
-                values, batch_size = 2^15
-            )
-            # Do GPU classification
-            values <- .try(
-                stats::predict(object = torch_model, values),
-                .msg_error = .conf("messages", ".check_gpu_memory_size")
-            )
-        } else {
-            # Do CPU classification
-            values <- stats::predict(object = torch_model, values)
-        }
+        # predict using CPU or GPU depending on machine
+        values <- .torch_predict(values, torch_model)
         # Convert to tensor cpu to support GPU processing
         values <- torch::as_array(
             x = torch::torch_tensor(values, device = "cpu")
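
Net effect of the MPS changes across these files, sketched as the observable behavior of the three helpers after this commit:

    .torch_has_mps()                # now hard-wired to FALSE
    .torch_gpu_enabled(ml_model)    # TRUE only for torch models with CUDA
    .torch_mps_train()              # still TRUE when MPS is detected,
                                    # which routes training to the CPU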

man/sits_classify.Rd (+1 -1)
man/sits_mlp.Rd (+2 -1)
man/sits_tempcnn.Rd (+1 -1)

(Generated documentation files; diffs not rendered.)
