Skip to content

Commit 9eaaa60

Browse files
support for training rfor and xgboost with DEM base cubes
1 parent f1de4b6 commit 9eaaa60

17 files changed

+57
-18
lines changed

NAMESPACE

+3
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ S3method(.samples_alloc_strata,class_cube)
136136
S3method(.samples_alloc_strata,class_vector_cube)
137137
S3method(.samples_bands,sits)
138138
S3method(.samples_bands,sits_base)
139+
S3method(.samples_select_bands,sits)
140+
S3method(.samples_select_bands,sits_base)
139141
S3method(.slice_dfr,numeric)
140142
S3method(.source_collection_access_test,"mpc_cube_sentinel-1-grd")
141143
S3method(.source_collection_access_test,cdse_cube)
@@ -213,6 +215,7 @@ S3method(.tile_as_sf,raster_cube)
213215
S3method(.tile_band_conf,default)
214216
S3method(.tile_band_conf,derived_cube)
215217
S3method(.tile_band_conf,eo_cube)
218+
S3method(.tile_bands,base_raster_cube)
216219
S3method(.tile_bands,default)
217220
S3method(.tile_bands,raster_cube)
218221
S3method(.tile_bbox,default)

R/api_check.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -1297,7 +1297,7 @@
12971297
len_min = 1,
12981298
len_max = 1
12991299
)
1300-
output_dir <- .file_normalize(output_dir)
1300+
output_dir <- .file_path_expand(output_dir)
13011301
.check_file(output_dir)
13021302
return(invisible(output_dir))
13031303
}

R/api_download.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
.download_asset <- function(asset, res, sf_roi, n_tries, output_dir,
1111
progress, ...) {
1212
# Get all paths and expand
13-
file <- .file_normalize(.tile_path(asset))
13+
file <- .file_path_expand(.tile_path(asset))
1414
# Create a list of user parameters as gdal format
1515
gdal_params <- .gdal_format_params(
1616
asset = asset,

R/api_file.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
#' @noRd
3232
#' @param file File name
3333
#' @returns File base name with path expanded
34-
.file_normalize <- function(file) {
34+
.file_path_expand <- function(file) {
3535
path.expand(file)
3636
}
3737
#' @title Build a file path
@@ -52,7 +52,7 @@
5252
}
5353
if (.has(output_dir)) {
5454
output_dir <- gsub("[/]*$", "", output_dir)
55-
output_dir <- .file_normalize(output_dir)
55+
output_dir <- .file_path_expand(output_dir)
5656
if (!dir.exists(output_dir) && create_dir) {
5757
dir.create(output_dir, recursive = TRUE)
5858
}

R/api_file_info.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ NULL
9090
.check_set_caller(".fi_eo_from_files")
9191
# precondition
9292
.check_that(length(files) == length(bands))
93-
files <- .file_normalize(files)
93+
files <- .file_path_expand(files)
9494
r_obj <- .raster_open_rast(files)
9595
.fi_eo(
9696
fid = fid[[1]],
@@ -148,7 +148,7 @@ NULL
148148
#' @param start_date start date of the image
149149
#' @param end_date end date of the image
150150
.fi_derived_from_file <- function(file, band, start_date, end_date) {
151-
file <- .file_normalize(file)
151+
file <- .file_path_expand(file)
152152
r_obj <- .raster_open_rast(file)
153153
.fi_derived(
154154
band = band,

R/api_raster.R

+4-4
Original file line numberDiff line numberDiff line change
@@ -765,8 +765,8 @@
765765
missing_value) {
766766
# Create an empty image template
767767
gdalUtilities::gdal_translate(
768-
src_dataset = .file_normalize(base_file),
769-
dst_dataset = .file_normalize(out_file),
768+
src_dataset = .file_path_expand(base_file),
769+
dst_dataset = .file_path_expand(out_file),
770770
ot = .raster_gdal_datatype(data_type),
771771
of = "GTiff",
772772
b = rep(1, nlayers),
@@ -815,13 +815,13 @@
815815
# for each file merge blocks
816816
for (i in seq_along(out_files)) {
817817
# Expand paths for out_file
818-
out_file <- .file_normalize(out_files[[i]])
818+
out_file <- .file_path_expand(out_files[[i]])
819819
# Check if out_file does not exist
820820
.check_that(!file.exists(out_file))
821821
# Get file paths
822822
merge_files <- purrr::map_chr(block_files, `[[`, i)
823823
# Expand paths for block_files
824-
merge_files <- .file_normalize(merge_files)
824+
merge_files <- .file_path_expand(merge_files)
825825
# check if block_files length is at least one
826826
.check_file(
827827
x = merge_files,

R/api_raster_terra.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@
9595
#' @export
9696
.raster_open_rast.terra <- function(file, ...) {
9797
r_obj <- suppressWarnings(
98-
terra::rast(x = .file_normalize(file), ...)
98+
terra::rast(x = .file_path_expand(file), ...)
9999
)
100100
.check_null_parameter(r_obj)
101101
# remove gain and offset applied by terra

R/api_samples.R

+20-3
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,16 @@
9696
#' @export
9797
.samples_bands.sits <- function(samples) {
9898
# Bands of the first sample governs whole samples data
99-
setdiff(names(.samples_ts(samples)), "Index")
99+
bands <- setdiff(names(.samples_ts(samples)), "Index")
100+
return(bands)
100101
}
101102
#' @export
102103
.samples_bands.sits_base <- function(samples) {
103104
# Bands of the first sample governs whole samples data
104105
ts_bands <- .samples_bands.sits(samples)
105106
base_bands <- .samples_bands_base(samples)
106107
bands <- c(ts_bands, base_bands)
108+
return(bands)
107109
}
108110
#' @title Get bands of base data for samples
109111
#' @noRd
@@ -128,8 +130,23 @@
128130
#' @param bands Bands to be selected
129131
#' @return Time series samples with the selected bands
130132
.samples_select_bands <- function(samples, bands) {
133+
UseMethod(".samples_select_bands", samples)
134+
}
135+
#' @export
136+
.samples_select_bands.sits <- function(samples, bands) {
131137
# Filter samples
132-
.ts(samples) <- .ts_select_bands(ts = .ts(samples), bands = bands)
138+
.ts(samples) <- .ts_select_bands(ts = .ts(samples),
139+
bands = bands)
140+
# Return samples
141+
samples
142+
}
143+
#' @export
144+
.samples_select_bands.sits_base <- function(samples, bands) {
145+
ts_bands <- .samples_bands.sits(samples)
146+
ts_select_bands <- bands[bands %in% ts_bands]
147+
# Filter time series samples
148+
.ts(samples) <- .ts_select_bands(ts = .ts(samples),
149+
bands = ts_select_bands)
133150
# Return samples
134151
samples
135152
}
@@ -192,7 +209,7 @@
192209
# Get all time series
193210
preds <- .samples_ts(samples)
194211
# Select attributes
195-
preds <- preds[.samples_bands(samples)]
212+
preds <- preds[.samples_bands.sits(samples)]
196213
# Compute stats
197214
q02 <- apply(preds, 2, stats::quantile, probs = 0.02, na.rm = TRUE)
198215
q98 <- apply(preds, 2, stats::quantile, probs = 0.98, na.rm = TRUE)

R/api_vector_info.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ NULL
4545
}
4646

4747
.vi_segment_from_file <- function(file, base_tile, band, start_date, end_date) {
48-
file <- .file_normalize(file)
48+
file <- .file_path_expand(file)
4949
v_obj <- .vector_read_vec(file_path = file)
5050
bbox <- .vector_bbox(v_obj)
5151
.vi_derived(

R/sits_lighttae.R

+3
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ sits_lighttae <- function(samples = NULL,
117117
.check_set_caller("sits_lighttae")
118118
# Function that trains a torch model based on samples
119119
train_fun <- function(samples) {
120+
# does not support working with DEM or other base data
121+
if (inherits(samples, "sits_base"))
122+
stop(.conf("messages", "sits_train_base_data"), call. = FALSE)
120123
# Avoid add a global variable for 'self'
121124
self <- NULL
122125
# Verifies if 'torch' and 'luz' packages is installed

R/sits_machine_learning.R

+3
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ sits_svm <- function(samples = NULL, formula = sits_formula_linear(),
165165
.check_set_caller("sits_svm")
166166
# Function that trains a support vector machine model
167167
train_fun <- function(samples) {
168+
# does not support working with DEM or other base data
169+
if (inherits(samples, "sits_base"))
170+
stop(.conf("messages", "sits_train_base_data"), call. = FALSE)
168171
# Verifies if e1071 package is installed
169172
.check_require_packages("e1071")
170173
# Get labels (used later to ensure column order in result matrix)

R/sits_mlp.R

+3
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ sits_mlp <- function(samples = NULL,
108108
.check_set_caller("sits_mlp")
109109
# Function that trains a torch model based on samples
110110
train_fun <- function(samples) {
111+
# does not support working with DEM or other base data
112+
if (inherits(samples, "sits_base"))
113+
stop(.conf("messages", "sits_train_base_data"), call. = FALSE)
111114
# Avoid add a global variable for 'self'
112115
self <- NULL
113116
# Verifies if 'torch' and 'luz' packages is installed

R/sits_regularize.R

+3-3
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ sits_regularize.raster_cube <- function(cube, ...,
120120
.check_raster_cube_files(cube)
121121
.check_period(period)
122122
.check_num_parameter(res, exclusive_min = 0)
123-
output_dir <- .file_normalize(output_dir)
123+
output_dir <- .file_path_expand(output_dir)
124124
.check_output_dir(output_dir)
125125
.check_num_parameter(multicores, min = 1, max = 2048)
126126
.check_progress(progress)
@@ -175,7 +175,7 @@ sits_regularize.sar_cube <- function(cube, ...,
175175
.check_raster_cube_files(cube)
176176
.check_period(period)
177177
.check_num_parameter(res, exclusive_min = 0)
178-
output_dir <- .file_normalize(output_dir)
178+
output_dir <- .file_path_expand(output_dir)
179179
.check_output_dir(output_dir)
180180
.check_num_parameter(multicores, min = 1, max = 2048)
181181
.check_progress(progress)
@@ -218,7 +218,7 @@ sits_regularize.dem_cube <- function(cube, ...,
218218
# Preconditions
219219
.check_raster_cube_files(cube)
220220
.check_num_parameter(res, exclusive_min = 0)
221-
output_dir <- .file_normalize(output_dir)
221+
output_dir <- .file_path_expand(output_dir)
222222
.check_output_dir(output_dir)
223223
.check_num_parameter(multicores, min = 1, max = 2048)
224224
.check_progress(progress)

R/sits_resnet.R

+3
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ sits_resnet <- function(samples = NULL,
126126
.check_set_caller("sits_resnet")
127127
# Function that trains a torch model based on samples
128128
train_fun <- function(samples) {
129+
# does not support working with DEM or other base data
130+
if (inherits(samples, "sits_base"))
131+
stop(.conf("messages", "sits_train_base_data"), call. = FALSE)
129132
# Avoid add a global variable for 'self'
130133
self <- NULL
131134
# Verifies if 'torch' and 'luz' packages is installed

R/sits_tae.R

+3
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ sits_tae <- function(samples = NULL,
107107
.check_set_caller("sits_tae")
108108
# Function that trains a torch model based on samples
109109
train_fun <- function(samples) {
110+
# does not support working with DEM or other base data
111+
if (inherits(samples, "sits_base"))
112+
stop(.conf("messages", "sits_train_base_data"), call. = FALSE)
110113
# Avoid add a global variable for 'self'
111114
self <- NULL
112115
# Verifies if 'torch' and 'luz' packages is installed

R/sits_tempcnn.R

+3
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ sits_tempcnn <- function(samples = NULL,
114114
.check_set_caller("sits_tempcnn")
115115
# Function that trains a torch model based on samples
116116
train_fun <- function(samples) {
117+
# does not support working with DEM or other base data
118+
if (inherits(samples, "sits_base"))
119+
stop(.conf("messages", "sits_train_base_data"), call. = FALSE)
117120
# Avoid add a global variable for 'self'
118121
self <- NULL
119122
# Verifies if 'torch' and 'luz' packages is installed

inst/extdata/config_messages.yml

+1
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ sits_stratified_sampling_shp_save: "saved allocation in shapefile"
430430
sits_svm: "wrong input parameters - see example in documentation"
431431
sits_tae: "wrong input parameters - see example in documentation"
432432
sits_tempcnn: "wrong input parameters - see example in documentation"
433+
sits_train_base_data: "training samples with DEM or other base data is only supported by random forest and xgboost methods"
433434
sits_timeline_raster_cube: "cube is not regular, returning all timelines"
434435
sits_timeline_default: "input should be a set of training samples or a data cube"
435436
sits_to_csv: "invalid CSV file to be written to"

0 commit comments

Comments
 (0)