Skip to content

Commit abb39b5

Browse files
fix M3 chip torch bug
1 parent 7e98e88 commit abb39b5

12 files changed

+33
-1
lines changed

R/sits_lighttae.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,11 @@ sits_lighttae <- function(samples = NULL,
267267
return(out)
268268
}
269269
)
270+
# torch 12.0 not working with Apple MPS
271+
if (torch::backends_mps_is_available())
272+
cpu_train <- TRUE
273+
else
274+
cpu_train <- FALSE
270275
# Train the model using luz
271276
torch_model <-
272277
luz::setup(
@@ -300,6 +305,7 @@ sits_lighttae <- function(samples = NULL,
300305
gamma = lr_decay_rate
301306
)
302307
),
308+
accelerator = luz::accelerator(cpu = cpu_train),
303309
dataloader_options = list(batch_size = batch_size),
304310
verbose = verbose
305311
)

R/sits_mlp.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,11 @@ sits_mlp <- function(samples = NULL,
237237
self$model(x)
238238
}
239239
)
240+
# torch 12.0 not working with Apple MPS
241+
if (torch::backends_mps_is_available())
242+
cpu_train <- TRUE
243+
else
244+
cpu_train <- FALSE
240245
# Train the model using luz
241246
torch_model <-
242247
luz::setup(
@@ -262,6 +267,8 @@ sits_mlp <- function(samples = NULL,
262267
patience = patience,
263268
min_delta = min_delta
264269
)),
270+
dataloader_options = list(batch_size = batch_size),
271+
accelerator = luz::accelerator(cpu = cpu_train),
265272
verbose = verbose
266273
)
267274
# Serialize model

R/sits_resnet.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,11 @@ sits_resnet <- function(samples = NULL,
319319
self$softmax()
320320
}
321321
)
322+
# torch 12.0 not working with Apple MPS
323+
if (torch::backends_mps_is_available())
324+
cpu_train <- TRUE
325+
else
326+
cpu_train <- FALSE
322327
# train the model using luz
323328
torch_model <-
324329
luz::setup(
@@ -354,6 +359,7 @@ sits_resnet <- function(samples = NULL,
354359
gamma = lr_decay_rate
355360
)
356361
),
362+
accelerator = luz::accelerator(cpu = cpu_train),
357363
dataloader_options = list(batch_size = batch_size),
358364
verbose = verbose
359365
)

R/sits_tae.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,11 @@ sits_tae <- function(samples = NULL,
240240
return(x)
241241
}
242242
)
243+
# torch 12.0 not working with Apple MPS
244+
if (torch::backends_mps_is_available())
245+
cpu_train <- TRUE
246+
else
247+
cpu_train <- FALSE
243248
# train the model using luz
244249
torch_model <-
245250
luz::setup(
@@ -273,6 +278,7 @@ sits_tae <- function(samples = NULL,
273278
gamma = lr_decay_rate
274279
)
275280
),
281+
accelerator = luz::accelerator(cpu = cpu_train),
276282
dataloader_options = list(batch_size = batch_size),
277283
verbose = verbose
278284
)

R/sits_tempcnn.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959
#' @examples
6060
#' if (sits_run_examples()) {
6161
#' # create a TempCNN model
62-
#' torch_model <- sits_train(samples_modis_ndvi, sits_tempcnn())
62+
#' torch_model <- sits_train(samples_modis_ndvi,
63+
#' sits_tempcnn(verbose = TRUE))
6364
#' # plot the model
6465
#' plot(torch_model)
6566
#' # create a data cube from local files
@@ -285,6 +286,11 @@ sits_tempcnn <- function(samples = NULL,
285286
self$softmax()
286287
}
287288
)
289+
# torch 12.0 not working with Apple MPS
290+
if (torch::backends_mps_is_available())
291+
cpu_train <- TRUE
292+
else
293+
cpu_train <- FALSE
288294
# Train the model using luz
289295
torch_model <-
290296
luz::setup(
@@ -323,6 +329,7 @@ sits_tempcnn <- function(samples = NULL,
323329
gamma = lr_decay_rate
324330
)
325331
),
332+
accelerator = luz::accelerator(cpu = cpu_train),
326333
dataloader_options = list(batch_size = batch_size),
327334
verbose = verbose
328335
)

mnist-cnn.pt

10.5 MB
Binary file not shown.

mnist/mnist/processed/test.rds

2.82 MB
Binary file not shown.

mnist/mnist/processed/training.rds

16.9 MB
Binary file not shown.
1.57 MB
Binary file not shown.
4.44 KB
Binary file not shown.

0 commit comments

Comments
 (0)