Skip to content

Commit 315173b

Browse files
fix support for MPS Apple
1 parent dcf508a commit 315173b

10 files changed

+111
-514
lines changed

DESCRIPTION

-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,6 @@ Collate:
246246
'sits_reclassify.R'
247247
'sits_reduce.R'
248248
'sits_regularize.R'
249-
'sits_resnet.R'
250249
'sits_sample_functions.R'
251250
'sits_segmentation.R'
252251
'sits_select.R'

R/api_torch.R

+9-10
Original file line numberDiff line numberDiff line change
@@ -372,8 +372,7 @@
372372
torch::cuda_is_available()
373373
}
374374
.torch_has_mps <- function(){
375-
# torch::backends_mps_is_available()
376-
return(FALSE)
375+
torch::backends_mps_is_available()
377376
}
378377

379378
.torch_mem_info <- function() {
@@ -400,8 +399,9 @@
400399
#' @return TRUE/FALSE
401400
#'
402401
.torch_gpu_enabled <- function(ml_model){
403-
gpu_enabled <- (inherits(ml_model, "torch_model") &&
404-
(.torch_has_cuda() || .torch_has_mps())
402+
gpu_enabled <- (
403+
inherits(ml_model, "torch_model") &&
404+
.torch_has_cuda()
405405
)
406406
return(gpu_enabled)
407407
}
@@ -439,7 +439,7 @@
439439
return(values)
440440
}
441441
#' @title Use GPU or CPU train for MPS Apple
442-
#' @name .torch_mps_train
442+
#' @name .torch_cpu_train
443443
#' @author Gilberto Camara, \email{gilberto.camara@@inpe.br}
444444
#' @keywords internal
445445
#' @noRd
@@ -448,12 +448,11 @@
448448
#'
449449
#' @return TRUE/FALSE
450450
#'
451-
.torch_mps_train <- function() {
452-
if (torch::backends_mps_is_available())
453-
cpu_train <- TRUE
454-
else
451+
.torch_cpu_train <- function() {
452+
if (torch::cuda_is_available())
455453
cpu_train <- FALSE
456-
454+
else
455+
cpu_train <- TRUE
457456
return(cpu_train)
458457
}
459458

R/sits_classify.R

+8-1
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,18 @@ sits_classify.raster_cube <- function(data,
233233
)
234234
# Calculate available memory from GPU
235235
memsize <- floor(gpu_memory - .torch_mem_info())
236-
.check_int_parameter(memsize, min = 2,
236+
.check_int_parameter(memsize, min = 1,
237237
msg = .conf("messages", ".check_gpu_memory_size")
238238
)
239239
proc_bloat <- .conf("processing_bloat_gpu")
240240
}
241+
# avoid memory race in Apple MPS
242+
if(.torch_has_mps()){
243+
memsize <- 1
244+
gpu_memory <- 1
245+
}
246+
# save memsize for latter use
247+
sits_env[["gpu_memory"]] <- gpu_memory
241248
# Spatial filter
242249
if (.has(roi)) {
243250
roi <- .roi_as_sf(roi)

R/sits_lighttae.R

+21-12
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ sits_lighttae <- function(samples = NULL,
271271
}
272272
)
273273
# torch 12.0 not working with Apple MPS
274-
cpu_train <- .torch_mps_train()
274+
cpu_train <- .torch_cpu_train()
275275
# Train the model using luz
276276
torch_model <-
277277
luz::setup(
@@ -339,17 +339,26 @@ sits_lighttae <- function(samples = NULL,
339339
values <- array(
340340
data = as.matrix(values), dim = c(n_samples, n_times, n_bands)
341341
)
342-
# Predict using GPU if available
343-
# If not, use CPU
344-
values <- .torch_predict(values, torch_model)
345-
# Convert to tensor CPU
346-
values <- torch::as_array(
347-
x = torch::torch_tensor(values, device = "cpu")
348-
)
349-
# Are the results consistent with the data input?
350-
.check_processed_values(
351-
values = values, input_pixels = input_pixels
352-
)
342+
# Load into GPU
343+
if (.torch_has_cuda()){
344+
# set the batch size according to the GPU memory
345+
gpu_memory <- sits_env[["gpu_memory"]]
346+
b_size <- 2^gpu_memory
347+
# transfor the input array to a dataset
348+
values <- .as_dataset(values)
349+
# To the data set to a torcj transform in a dataloader to use the batch size
350+
values <- torch::dataloader(values, batch_size = b_size)
351+
# Do GPU classification with dataloader
352+
values <- .try(
353+
stats::predict(object = torch_model, values),
354+
.msg_error = .conf("messages", ".check_gpu_memory_size")
355+
)
356+
} else {
357+
# Do classification without dataloader
358+
values <- stats::predict(object = torch_model, values)
359+
}
360+
# Convert from tensor to array
361+
values <- torch::as_array(values)
353362
# Update the columns names to labels
354363
colnames(values) <- labels
355364
return(values)

R/sits_mlp.R

+25-13
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ sits_mlp <- function(samples = NULL,
205205
test_y <- unname(code_labels[.pred_references(test_samples)])
206206
# Set torch seed
207207
torch::torch_manual_seed(sample.int(10^5, 1))
208-
208+
# train with CPU or GPU?
209+
cpu_train <- .torch_cpu_train()
209210
# Define the MLP architecture
210211
mlp_model <- torch::nn_module(
211212
initialize = function(num_pred, layers, dropout_rates, y_dim) {
@@ -241,8 +242,8 @@ sits_mlp <- function(samples = NULL,
241242
self$model(x)
242243
}
243244
)
244-
# torch 12.0 not working with Apple MPS
245-
cpu_train <- .torch_mps_train()
245+
# Train with CPU or GPU?
246+
cpu_train <- .torch_cpu_train()
246247
# Train the model using luz
247248
torch_model <-
248249
luz::setup(
@@ -290,16 +291,27 @@ sits_mlp <- function(samples = NULL,
290291
values <- .pred_normalize(pred = values, stats = ml_stats)
291292
# Transform input into matrix
292293
values <- as.matrix(values)
293-
# predict using CPU or GPU depending on machine
294-
values <- .torch_predict(values, torch_model)
295-
# Convert to tensor cpu to support GPU processing
296-
values <- torch::as_array(
297-
x = torch::torch_tensor(values, device = "cpu")
298-
)
299-
# Are the results consistent with the data input?
300-
.check_processed_values(
301-
values = values, input_pixels = input_pixels
302-
)
294+
# if CUDA is available, transform to torch data set
295+
# Load into GPU
296+
if (.torch_has_cuda()){
297+
# set the batch size according to the GPU memory
298+
gpu_memory <- sits_env[["gpu_memory"]]
299+
b_size <- 2^gpu_memory
300+
# transfor the input array to a dataset
301+
values <- .as_dataset(values)
302+
# To the data set to a torcj transform in a dataloader to use the batch size
303+
values <- torch::dataloader(values, batch_size = b_size)
304+
# Do GPU classification with dataloader
305+
values <- .try(
306+
stats::predict(object = torch_model, values),
307+
.msg_error = .conf("messages", ".check_gpu_memory_size")
308+
)
309+
} else {
310+
# Do classification without dataloader
311+
values <- stats::predict(object = torch_model, values)
312+
}
313+
# Convert from tensor to array
314+
values <- torch::as_array(values)
303315
# Update the columns names to labels
304316
colnames(values) <- labels
305317
return(values)

0 commit comments

Comments
 (0)