Skip to content

Commit 22167ef

Browse files
Merge pull request #1482 from rolfsimoes/dev
Bug fixes and `sits_uncertainty_sampling()` improvements
2 parents 7d5d284 + 5202b31 commit 22167ef

8 files changed

Lines changed: 287 additions & 89 deletions

File tree

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ S3method(.source_collection_access_test,ogh_cube)
157157
S3method(.source_collection_access_test,stac_cube)
158158
S3method(.source_collection_access_test,usgs_cube)
159159
S3method(.source_cube,stac_cube)
160+
S3method(.source_filter_tiles,"bdc_cube_landsat-1y")
161+
S3method(.source_filter_tiles,"bdc_cube_landsat-2m")
160162
S3method(.source_filter_tiles,"cdse_os_cube_sentinel-1-rtc")
161163
S3method(.source_filter_tiles,"deafrica_cube_sentinel-1-rtc")
162164
S3method(.source_filter_tiles,"mpc_cube_cop-dem-glo-30")
@@ -179,6 +181,8 @@ S3method(.source_items_bands_select,stac_cube)
179181
S3method(.source_items_cube,stac_cube)
180182
S3method(.source_items_fid,stac_cube)
181183
S3method(.source_items_new,"aws_cube_landsat-c2-l2")
184+
S3method(.source_items_new,"bdc_cube_landsat-1y")
185+
S3method(.source_items_new,"bdc_cube_landsat-2m")
182186
S3method(.source_items_new,"deafrica_cube_sentinel-1-rtc")
183187
S3method(.source_items_new,"deafrica_cube_sentinel-2-l2a")
184188
S3method(.source_items_new,"mpc_cube_cop-dem-glo-30")

R/api_check.R

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -573,16 +573,20 @@
573573
msg = paste0("value should be <= ", max)
574574
)
575575
# exclusive_min and exclusive_max checks
576-
.check_that(
577-
all(x > exclusive_min),
578-
local_msg = local_msg,
579-
msg = paste0("value should be > ", exclusive_min)
580-
)
581-
.check_that(
582-
all(x < exclusive_max),
583-
local_msg = local_msg,
584-
msg = paste0("value should be < ", exclusive_max)
585-
)
576+
if (is.finite(exclusive_min)) {
577+
.check_that(
578+
all(x > exclusive_min),
579+
local_msg = local_msg,
580+
msg = paste0("value should be > ", exclusive_min)
581+
)
582+
}
583+
if (is.finite(exclusive_max)) {
584+
.check_that(
585+
all(x < exclusive_max),
586+
local_msg = local_msg,
587+
msg = paste0("value should be < ", exclusive_max)
588+
)
589+
}
586590
}
587591
#' @rdname check_functions
588592
#' @keywords internal

R/api_data.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,16 @@
320320
classes <- labels[class_numbers]
321321
# insert classes into samples
322322
samples[["label"]] <- unname(classes)
323+
# Preserve start_date and end_date if they exist in input samples
324+
cols_to_select <- c("longitude", "latitude", "label")
325+
if ("start_date" %in% names(samples)) {
326+
cols_to_select <- c(cols_to_select, "start_date")
327+
}
328+
if ("end_date" %in% names(samples)) {
329+
cols_to_select <- c(cols_to_select, "end_date")
330+
}
323331
samples <- dplyr::select(
324-
samples, dplyr::all_of("longitude"),
325-
dplyr::all_of("latitude"), dplyr::all_of("label")
332+
samples, dplyr::all_of(cols_to_select)
326333
)
327334
samples
328335
})

R/api_validate.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@
2727
# Call caret package to the classification statistics
2828
acc_obj <- caret::confusionMatrix(predicted, reference)
2929
# Set result class and return it
30-
.set_class(x = acc_obj, "sits_accuracy", class(acc_obj))
30+
class(acc_obj) <- c("sits_accuracy", class(acc_obj))
3131
acc_obj
3232
}

R/sits_uncertainty.R

Lines changed: 165 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,15 @@ sits_uncertainty.default <- function(cube, ...) {
205205
#' See \code{\link[sits]{sits_uncertainty}}.
206206
#' @param n Number of suggested points to be sampled per tile.
207207
#' @param min_uncert Minimum uncertainty value to select a sample.
208+
#' @param max_uncert Maximum uncertainty value to select a sample.
209+
#' Default is Inf (no upper limit).
208210
#' @param sampling_window Window size for collecting points (in pixels).
209211
#' The minimum window size is 10.
210212
#' @param multicores Number of workers for parallel processing
211213
#' (integer, min = 1, max = 2048).
212214
#' @param memsize Maximum overall memory (in GB) to run the
213215
#' function.
216+
#' @param progress Whether to show progress bars (TRUE/FALSE).
214217
#'
215218
#' @return
216219
#' A tibble with longitude and latitude in WGS84 with locations
@@ -254,83 +257,182 @@ sits_uncertainty.default <- function(cube, ...) {
254257
sits_uncertainty_sampling <- function(uncert_cube,
255258
n = 100L,
256259
min_uncert = 0.4,
260+
max_uncert = Inf,
257261
sampling_window = 10L,
258262
multicores = 2L,
259-
memsize = 4L) {
263+
memsize = 4L,
264+
progress = FALSE) {
260265
.check_set_caller("sits_uncertainty_sampling")
261266
# Pre-conditions
262267
.check_is_uncert_cube(uncert_cube)
263268
.check_int_parameter(n, min = 1L)
264269
.check_num_parameter(min_uncert, min = 0.0, max = 1.0)
270+
.check_num_parameter(max_uncert, min = 0.0)
265271
.check_int_parameter(sampling_window, min = 1L)
266-
.check_int_parameter(multicores, min = 1L)
267-
.check_int_parameter(memsize, min = 1L)
272+
.check_int_parameter(multicores, min = 1L, max = 2048L)
273+
.check_int_parameter(memsize, min = 1L, max = 16384L)
274+
progress <- .message_progress(progress)
275+
276+
# The following functions define optimal parameters for parallel processing
277+
# Get block size
278+
block <- .raster_file_blocksize(.raster_open_rast(.tile_path(uncert_cube)))
279+
# Overlapping pixels (no overlap for uncertainty sampling)
280+
overlap <- 0L
281+
# Check minimum memory needed to process one block
282+
job_block_memsize <- .jobs_block_memsize(
283+
block_size = .block_size(block = block, overlap = overlap),
284+
npaths = 1L,
285+
nbytes = 8L,
286+
proc_bloat = .conf("processing_bloat_cpu")
287+
)
288+
# Update multicores parameter
289+
multicores <- .jobs_max_multicores(
290+
job_block_memsize = job_block_memsize,
291+
memsize = memsize,
292+
multicores = multicores
293+
)
294+
# Update block parameter
295+
block <- .jobs_optimal_block(
296+
job_block_memsize = job_block_memsize,
297+
block = block,
298+
image_size = .tile_size(.tile(uncert_cube)),
299+
memsize = memsize,
300+
multicores = multicores
301+
)
302+
# Prepare parallel processing
303+
if (.parallel_start(workers = multicores)) {
304+
on.exit(.parallel_stop(), add = TRUE)
305+
}
268306
# Slide on cube tiles
269307
samples_tb <- slider::slide_dfr(uncert_cube, function(tile) {
270-
# open spatial raster object
271-
rast <- .raster_open_rast(.tile_path(tile))
272-
# get the values
273-
values <- .raster_get_values(rast)
274-
# sample the maximum values
275-
samples_tile <- C_max_sampling(
276-
x = values,
277-
nrows = nrow(rast),
278-
ncols = ncol(rast),
279-
window_size = sampling_window
308+
# Create chunks as jobs
309+
chunks <- .tile_chunks_create(
310+
tile = tile,
311+
overlap = overlap,
312+
block = block
280313
)
281-
# get the top most values
282-
samples_tile <- samples_tile |>
283-
# randomly shuffle the rows of the dataset
284-
dplyr::sample_frac() |>
285-
dplyr::slice_max(
286-
.data[["value"]],
287-
n = n,
288-
with_ties = FALSE
314+
# Tile path
315+
tile_path <- .tile_path(tile)
316+
317+
# Process jobs in parallel
318+
chunk_results <- .jobs_map_parallel_dfr(chunks, function(chunk) {
319+
# Open tile images
320+
r_obj <- .raster_open_rast(tile_path)
321+
322+
# Get values for this chunk only
323+
values <- .raster_get_values(
324+
rast = r_obj,
325+
row = .block(chunk)[["row"]],
326+
col = .block(chunk)[["col"]],
327+
nrows = .block(chunk)[["nrows"]],
328+
ncols = .block(chunk)[["ncols"]]
329+
)
330+
331+
# Sample the maximum values in this chunk
332+
samples_chunk <- C_max_sampling(
333+
x = values,
334+
nrows = .block(chunk)[["nrows"]],
335+
ncols = .block(chunk)[["ncols"]],
336+
window_size = sampling_window
289337
)
290-
# transform to tibble
291-
tb <- rast |>
292-
.raster_xy_from_cell(
293-
cell = samples_tile[["cell"]]
294-
) |>
295-
tibble::as_tibble()
296-
# find NA
297-
na_rows <- which(is.na(tb))
298-
# remove NA
299-
if (.has(na_rows)) {
300-
tb <- tb[-na_rows, ]
301-
samples_tile <- samples_tile[-na_rows, ]
302-
}
303-
# Get the values' positions.
304-
result_tile <- tb |>
305-
sf::st_as_sf(
306-
coords = c("x", "y"),
307-
crs = .raster_crs(rast),
308-
dim = "XY",
309-
remove = TRUE
310-
) |>
311-
sf::st_transform(crs = "EPSG:4326") |>
312-
sf::st_coordinates()
313338

314-
colnames(result_tile) <- c("longitude", "latitude")
315-
result_tile <- result_tile |>
316-
dplyr::bind_cols(samples_tile) |>
317-
dplyr::mutate(
318-
value = .data[["value"]] *
319-
.conf("probs_cube_scale_factor")
320-
) |>
321-
dplyr::filter(
322-
.data[["value"]] >= min_uncert
323-
) |>
324-
dplyr::select(dplyr::matches(
325-
c("longitude", "latitude", "value")
326-
)) |>
327-
tibble::as_tibble()
339+
# Skip empty chunks
340+
if (nrow(samples_chunk) == 0) {
341+
return(tibble::tibble(
342+
longitude = numeric(0),
343+
latitude = numeric(0),
344+
value = numeric(0)
345+
))
346+
}
328347

329-
# All the cube's uncertainty images have the same start & end dates.
330-
result_tile[["start_date"]] <- .tile_start_date(uncert_cube)
331-
result_tile[["end_date"]] <- .tile_end_date(uncert_cube)
332-
result_tile[["label"]] <- "NoClass"
333-
result_tile
348+
# Create a virtual raster object for the chunk
349+
chunk_obj <- .chunks_as_raster(
350+
chunk = chunk,
351+
nlayers = 1L
352+
)
353+
354+
# transform to tibble using chunk coordinates
355+
tb <- chunk_obj |>
356+
.raster_xy_from_cell(
357+
cell = samples_chunk[["cell"]]
358+
) |>
359+
tibble::as_tibble()
360+
361+
# find NA
362+
na_rows <- which(is.na(tb))
363+
# remove NA
364+
if (.has(na_rows)) {
365+
tb <- tb[-na_rows, ]
366+
samples_chunk <- samples_chunk[-na_rows, ]
367+
}
368+
369+
# Skip if all NA
370+
if (nrow(tb) == 0) {
371+
return(tibble::tibble(
372+
longitude = numeric(0),
373+
latitude = numeric(0),
374+
value = numeric(0)
375+
))
376+
}
377+
378+
# Get the values' positions.
379+
result_chunk <- tb |>
380+
sf::st_as_sf(
381+
coords = c("x", "y"),
382+
crs = .raster_crs(chunk_obj),
383+
dim = "XY",
384+
remove = TRUE
385+
) |>
386+
sf::st_transform(crs = "EPSG:4326") |>
387+
sf::st_coordinates()
388+
389+
colnames(result_chunk) <- c("longitude", "latitude")
390+
result_chunk <- result_chunk |>
391+
dplyr::bind_cols(samples_chunk) |>
392+
dplyr::mutate(
393+
value = .data[["value"]] *
394+
.conf("probs_cube_scale_factor")
395+
) |>
396+
dplyr::filter(
397+
.data[["value"]] >= min_uncert,
398+
.data[["value"]] <= max_uncert
399+
) |>
400+
dplyr::select(dplyr::matches(
401+
c("longitude", "latitude", "value")
402+
)) |>
403+
tibble::as_tibble()
404+
405+
result_chunk
406+
}, progress = progress)
407+
408+
# Aggregate: select the top n values from all chunks in this tile
409+
if (nrow(chunk_results) > 0) {
410+
chunk_results |>
411+
# randomly shuffle the rows of the dataset
412+
dplyr::slice_sample(
413+
prop = 1
414+
) |>
415+
dplyr::slice_max(
416+
.data[["value"]],
417+
n = n,
418+
with_ties = FALSE
419+
) |>
420+
dplyr::mutate(
421+
start_date = .tile_start_date(uncert_cube),
422+
end_date = .tile_end_date(uncert_cube),
423+
label = "NoClass"
424+
)
425+
} else {
426+
# Return empty result if no samples found
427+
tibble::tibble(
428+
longitude = numeric(0),
429+
latitude = numeric(0),
430+
value = numeric(0),
431+
start_date = character(0),
432+
end_date = character(0),
433+
label = character(0)
434+
)
435+
}
334436
})
335437
renamed_cols <- c(uncertainty = "value")
336438
samples_tb <- dplyr::rename(samples_tb, dplyr::all_of(renamed_cols))

man/sits_uncertainty_sampling.Rd

Lines changed: 8 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)