Skip to content

Commit abb39b5

Browse files
fix M3 chip torch bug
1 parent 7e98e88 commit abb39b5

12 files changed

+33
-1
lines changed

R/sits_lighttae.R

+6
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,11 @@ sits_lighttae <- function(samples = NULL,
267267
return(out)
268268
}
269269
)
270+
# torch 12.0 not working with Apple MPS
271+
if (torch::backends_mps_is_available())
272+
cpu_train <- TRUE
273+
else
274+
cpu_train <- FALSE
270275
# Train the model using luz
271276
torch_model <-
272277
luz::setup(
@@ -300,6 +305,7 @@ sits_lighttae <- function(samples = NULL,
300305
gamma = lr_decay_rate
301306
)
302307
),
308+
accelerator = luz::accelerator(cpu = cpu_train),
303309
dataloader_options = list(batch_size = batch_size),
304310
verbose = verbose
305311
)

R/sits_mlp.R

+7
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,11 @@ sits_mlp <- function(samples = NULL,
237237
self$model(x)
238238
}
239239
)
240+
# torch 12.0 not working with Apple MPS
241+
if (torch::backends_mps_is_available())
242+
cpu_train <- TRUE
243+
else
244+
cpu_train <- FALSE
240245
# Train the model using luz
241246
torch_model <-
242247
luz::setup(
@@ -262,6 +267,8 @@ sits_mlp <- function(samples = NULL,
262267
patience = patience,
263268
min_delta = min_delta
264269
)),
270+
dataloader_options = list(batch_size = batch_size),
271+
accelerator = luz::accelerator(cpu = cpu_train),
265272
verbose = verbose
266273
)
267274
# Serialize model

R/sits_resnet.R

+6
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,11 @@ sits_resnet <- function(samples = NULL,
319319
self$softmax()
320320
}
321321
)
322+
# torch 12.0 not working with Apple MPS
323+
if (torch::backends_mps_is_available())
324+
cpu_train <- TRUE
325+
else
326+
cpu_train <- FALSE
322327
# train the model using luz
323328
torch_model <-
324329
luz::setup(
@@ -354,6 +359,7 @@ sits_resnet <- function(samples = NULL,
354359
gamma = lr_decay_rate
355360
)
356361
),
362+
accelerator = luz::accelerator(cpu = cpu_train),
357363
dataloader_options = list(batch_size = batch_size),
358364
verbose = verbose
359365
)

R/sits_tae.R

+6
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,11 @@ sits_tae <- function(samples = NULL,
240240
return(x)
241241
}
242242
)
243+
# torch 12.0 not working with Apple MPS
244+
if (torch::backends_mps_is_available())
245+
cpu_train <- TRUE
246+
else
247+
cpu_train <- FALSE
243248
# train the model using luz
244249
torch_model <-
245250
luz::setup(
@@ -273,6 +278,7 @@ sits_tae <- function(samples = NULL,
273278
gamma = lr_decay_rate
274279
)
275280
),
281+
accelerator = luz::accelerator(cpu = cpu_train),
276282
dataloader_options = list(batch_size = batch_size),
277283
verbose = verbose
278284
)

R/sits_tempcnn.R

+8-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959
#' @examples
6060
#' if (sits_run_examples()) {
6161
#' # create a TempCNN model
62-
#' torch_model <- sits_train(samples_modis_ndvi, sits_tempcnn())
62+
#' torch_model <- sits_train(samples_modis_ndvi,
63+
#' sits_tempcnn(verbose = TRUE))
6364
#' # plot the model
6465
#' plot(torch_model)
6566
#' # create a data cube from local files
@@ -285,6 +286,11 @@ sits_tempcnn <- function(samples = NULL,
285286
self$softmax()
286287
}
287288
)
289+
# torch 12.0 not working with Apple MPS
290+
if (torch::backends_mps_is_available())
291+
cpu_train <- TRUE
292+
else
293+
cpu_train <- FALSE
288294
# Train the model using luz
289295
torch_model <-
290296
luz::setup(
@@ -323,6 +329,7 @@ sits_tempcnn <- function(samples = NULL,
323329
gamma = lr_decay_rate
324330
)
325331
),
332+
accelerator = luz::accelerator(cpu = cpu_train),
326333
dataloader_options = list(batch_size = batch_size),
327334
verbose = verbose
328335
)

mnist-cnn.pt

10.5 MB
Binary file not shown.

mnist/mnist/processed/test.rds

2.82 MB
Binary file not shown.

mnist/mnist/processed/training.rds

16.9 MB
Binary file not shown.
1.57 MB
Binary file not shown.
4.44 KB
Binary file not shown.
9.45 MB
Binary file not shown.
28.2 KB
Binary file not shown.

0 commit comments

Comments
 (0)