Skip to content

Commit e6e1526

Browse files
include support for MPS
1 parent 54af3e3 commit e6e1526

9 files changed

+42
-24
lines changed

R/api_classify.R

+7-3
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@
196196
verbose = verbose
197197
)
198198
# Clean torch allocations
199-
if (.is_torch_model(ml_model)) {
199+
if (.is_torch_model(ml_model) && .torch_has_cuda()) {
200200
torch::cuda_empty_cache()
201201
}
202202
# Return probs tile
@@ -527,7 +527,9 @@
527527
)
528528
}
529529
# choose between GPU and CPU
530-
if (inherits(ml_model, "torch_model") && torch::cuda_is_available())
530+
if (inherits(ml_model, "torch_model") &&
531+
(.torch_has_cuda() || .torch_has_mps())
532+
)
531533
prediction <- .classify_ts_gpu(
532534
pred = pred,
533535
ml_model = ml_model,
@@ -629,7 +631,9 @@
629631
# Return classification
630632
values <- tibble::tibble(data.frame(values))
631633
# Clean torch cache
632-
torch::cuda_empty_cache()
634+
if (.is_torch_model(ml_model) && .torch_has_cuda()) {
635+
torch::cuda_empty_cache()
636+
}
633637
return(values)
634638
})
635639
return(prediction)

R/api_torch.R

+16-5
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,25 @@
366366
)
367367

368368
.is_torch_model <- function(ml_model) {
369-
inherits(ml_model, "torch_model") && torch::cuda_is_available()
369+
inherits(ml_model, "torch_model")
370+
}
371+
.torch_has_cuda <- function(){
372+
torch::cuda_is_available()
373+
}
374+
.torch_has_mps <- function(){
375+
torch::backends_mps_is_available()
370376
}
371377

372378
.torch_mem_info <- function() {
373-
# Get memory summary
374-
mem_sum <- torch::cuda_memory_stats()
375-
# Return current memory info in GB
376-
mem_sum[["allocated_bytes"]][["all"]][["current"]] / 10^9
379+
if (.torch_has_cuda()){
380+
# Get memory summary
381+
mem_sum <- torch::cuda_memory_stats()
382+
# Return current memory info in GB
383+
mem_sum[["allocated_bytes"]][["all"]][["current"]] / 10^9
384+
} else {
385+
mem_sum <- 0
386+
}
387+
return(mem_sum)
377388
}
378389

379390
.as_dataset <- torch::dataset(

R/sits_classify.R

+7-4
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ sits_classify.sits <- function(data,
170170
filter_fn = NULL,
171171
impute_fn = impute_linear(),
172172
multicores = 2L,
173-
gpu_memory = 16,
173+
gpu_memory = 4,
174174
progress = TRUE) {
175175
# set caller for error messages
176176
.check_set_caller("sits_classify_sits")
@@ -207,13 +207,14 @@ sits_classify.raster_cube <- function(data,
207207
end_date = NULL,
208208
memsize = 8L,
209209
multicores = 2L,
210-
gpu_memory = 16,
210+
gpu_memory = 4,
211211
output_dir,
212212
version = "v1",
213213
verbose = FALSE,
214214
progress = TRUE) {
215215
# set caller for error messages
216216
.check_set_caller("sits_classify_raster")
217+
# reduce GPU memory for MPS
217218
# preconditions
218219
.check_is_raster_cube(data)
219220
.check_that(.cube_is_regular(data))
@@ -227,7 +228,8 @@ sits_classify.raster_cube <- function(data,
227228
# Get default proc bloat
228229
proc_bloat <- .conf("processing_bloat_cpu")
229230
# If we are using the GPU, the gpu_memory parameter needs to be specified
230-
if (.is_torch_model(ml_model)) {
231+
if (.is_torch_model(ml_model) &&
232+
(.torch_has_cuda() || .torch_has_mps())) {
231233
.check_int_parameter(gpu_memory, min = 1, max = 16384,
232234
msg = .conf("messages", ".check_gpu_memory")
233235
)
@@ -375,7 +377,8 @@ sits_classify.segs_cube <- function(data,
375377
.check_progress(progress)
376378
proc_bloat <- .conf("processing_bloat_seg_class")
377379
# If we are using the GPU, the gpu_memory parameter needs to be specified
378-
if (.is_torch_model(ml_model)) {
380+
if (.is_torch_model(ml_model) &&
381+
(.torch_has_cuda() || .torch_has_mps())) {
379382
.check_int_parameter(gpu_memory, min = 1, max = 16384,
380383
msg = .conf("messages", ".check_gpu_memory")
381384
)

R/sits_lighttae.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ sits_lighttae <- function(samples = NULL,
344344
)
345345
# if CUDA or MPS is available, transform to torch data set
346346
# Load into GPU
347-
if (torch::cuda_is_available()) {
347+
if (.torch_has_cuda() || .torch_has_mps()) {
348348
values <- .as_dataset(values)
349349
# We need to transform in a dataloader to use the batch size
350350
values <- torch::dataloader(

R/sits_mlp.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ sits_mlp <- function(samples = NULL,
294294
values <- as.matrix(values)
295295
# if CUDA or MPS is available, transform to torch data set
296296
# Load into GPU
297-
if (torch::cuda_is_available()) {
297+
if (.torch_has_cuda() || .torch_has_mps()) {
298298
values <- .as_dataset(values)
299299
# We need to transform in a dataloader to use the batch size
300300
values <- torch::dataloader(

R/sits_resnet.R

+5-5
Original file line numberDiff line numberDiff line change
@@ -322,11 +322,11 @@ sits_resnet <- function(samples = NULL,
322322
self$softmax()
323323
}
324324
)
325-
# torch 12.0 not working with Apple MPS
326-
if (torch::backends_mps_is_available())
327-
cpu_train <- TRUE
325+
# torch 12.0 not working with Apple MPS
326+
if (.torch_has_mps())
327+
cpu_train <- TRUE
328328
else
329-
cpu_train <- FALSE
329+
cpu_train <- FALSE
330330
# train the model using luz
331331
torch_model <-
332332
luz::setup(
@@ -392,7 +392,7 @@ sits_resnet <- function(samples = NULL,
392392
)
393393
# if CUDA or MPS is available, transform to torch data set
394394
# Load into GPU
395-
if (torch::cuda_is_available()) {
395+
if (.torch_has_cuda() || .torch_has_mps()) {
396396
values <- .as_dataset(values)
397397
# We need to transform in a dataloader to use the batch size
398398
values <- torch::dataloader(

R/sits_tae.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ sits_tae <- function(samples = NULL,
244244
}
245245
)
246246
# torch 12.0 not working with Apple MPS
247-
if (torch::backends_mps_is_available())
247+
if (.torch_has_mps())
248248
cpu_train <- TRUE
249249
else
250250
cpu_train <- FALSE
@@ -311,7 +311,7 @@ sits_tae <- function(samples = NULL,
311311
)
312312
# if CUDA or MPS is available, transform to torch data set
313313
# Load into GPU
314-
if (torch::cuda_is_available()) {
314+
if (.torch_has_cuda() || .torch_has_mps()) {
315315
values <- .as_dataset(values)
316316
# We need to transform in a dataloader to use the batch size
317317
values <- torch::dataloader(

R/sits_tempcnn.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ sits_tempcnn <- function(samples = NULL,
363363
)
364364
# if CUDA or MPS is available, transform to torch data set
365365
# Load into GPU
366-
if (torch::cuda_is_available()) {
366+
if (.torch_has_cuda() || .torch_has_mps()) {
367367
values <- .as_dataset(values)
368368
# We need to transform in a dataloader to use the batch size
369369
values <- torch::dataloader(

man/sits_classify.Rd

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)