Skip to content

Commit 3a5aa79

Browse files
Merge pull request #1129 from M3nin0/feature/active-learning-multicores
active learning sampling: add multicores processing support
2 parents 71fae2d + a4d7211 commit 3a5aa79

5 files changed

+169
-51
lines changed

R/api_raster.R

+13-5
Original file line numberDiff line numberDiff line change
@@ -168,32 +168,40 @@
168168
#' locations are guaranteed to be separated by a certain number of pixels.
169169
#'
170170
#' @param r_obj A raster object.
171+
#' @param block Block of the raster to be read and processed, with named
#'              elements `row`, `col`, `nrows` and `ncols`.
171172
#' @param band A numeric band index used to read bricks.
172173
#' @param n Number of values to extract.
173174
#' @param sampling_window Window size to collect a point (in pixels).
174175
#'
175176
#' @return A point `tibble` object.
176177
#'
177178
.raster_get_top_values <- function(r_obj,
179+
block,
178180
band,
179181
n,
180182
sampling_window) {
181183
# Pre-conditions have been checked in calling functions
182184
# Get top values
183185
# filter by median to avoid borders
184186
# Process window
185-
values <- .raster_get_values(r_obj)
187+
values <- .raster_get_values(
188+
r_obj,
189+
row = block[["row"]],
190+
col = block[["col"]],
191+
nrows = block[["nrows"]],
192+
ncols = block[["ncols"]]
193+
)
186194
values <- C_kernel_median(
187195
x = values,
188-
ncols = .raster_ncols(r_obj),
189-
nrows = .raster_nrows(r_obj),
196+
nrows = block[["nrows"]],
197+
ncols = block[["ncols"]],
190198
band = 0,
191199
window_size = sampling_window
192200
)
193201
samples_tb <- C_max_sampling(
194202
x = values,
195-
nrows = .raster_nrows(r_obj),
196-
ncols = .raster_ncols(r_obj),
203+
nrows = block[["nrows"]],
204+
ncols = block[["ncols"]],
197205
window_size = sampling_window
198206
)
199207
samples_tb <- dplyr::slice_max(

R/sits_active_learning.R

+129-41
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
#' @param min_uncert Minimum uncertainty value to select a sample.
3232
#' @param sampling_window Window size for collecting points (in pixels).
3333
#' The minimum window size is 10.
34+
#' @param multicores Number of workers for parallel processing
35+
#' (integer, min = 1, max = 2048).
36+
#' @param memsize Maximum overall memory (in GB) to run the
37+
#' function (integer, min = 1, max = 16384).
3438
#'
3539
#' @return
3640
#' A tibble with longitude and latitude in WGS84 with locations
@@ -75,23 +79,64 @@
7579
sits_uncertainty_sampling <- function(uncert_cube,
7680
n = 100L,
7781
min_uncert = 0.4,
78-
sampling_window = 10L) {
82+
sampling_window = 10L,
83+
multicores = 1L,
84+
memsize = 1L) {
7985
.check_set_caller("sits_uncertainty_sampling")
80-
8186
# Pre-conditions
8287
.check_is_uncert_cube(uncert_cube)
8388
.check_int_parameter(n, min = 1, max = 10000)
8489
.check_num_parameter(min_uncert, min = 0.2, max = 1.0)
8590
.check_int_parameter(sampling_window, min = 10L)
86-
91+
.check_int_parameter(multicores, min = 1, max = 2048)
92+
.check_int_parameter(memsize, min = 1, max = 16384)
93+
# Get block size
94+
block <- .raster_file_blocksize(.raster_open_rast(.tile_path(uncert_cube)))
95+
# Overlapping pixels
96+
overlap <- ceiling(sampling_window / 2) - 1
97+
# Check minimum memory needed to process one block
98+
job_memsize <- .jobs_memsize(
99+
job_size = .block_size(block = block, overlap = overlap),
100+
npaths = sampling_window,
101+
nbytes = 8,
102+
proc_bloat = .conf("processing_bloat_cpu")
103+
)
104+
# Update multicores parameter
105+
multicores <- .jobs_max_multicores(
106+
job_memsize = job_memsize,
107+
memsize = memsize,
108+
multicores = multicores
109+
)
110+
# Update block parameter
111+
block <- .jobs_optimal_block(
112+
job_memsize = job_memsize,
113+
block = block,
114+
image_size = .tile_size(.tile(uncert_cube)),
115+
memsize = memsize,
116+
multicores = multicores
117+
)
118+
# Prepare parallel processing
119+
.parallel_start(workers = multicores)
120+
on.exit(.parallel_stop(), add = TRUE)
87121
# Slide on cube tiles
88122
samples_tb <- slider::slide_dfr(uncert_cube, function(tile) {
89-
path <- .tile_path(tile)
123+
# Create chunks as jobs
124+
chunks <- .tile_chunks_create(
125+
tile = tile,
126+
overlap = overlap,
127+
block = block
128+
)
129+
# Tile path
130+
tile_path <- .tile_path(tile)
90131
# Get a list of values of high uncertainty
91-
top_values <- .raster_open_rast(path) |>
132+
# Process jobs in parallel
133+
top_values <- .jobs_map_parallel_dfr(chunks, function(chunk) {
134+
# Read and preprocess values
135+
.raster_open_rast(tile_path) |>
92136
.raster_get_top_values(
93-
band = 1,
94-
n = n,
137+
block = .block(chunk),
138+
band = 1,
139+
n = n,
95140
sampling_window = sampling_window
96141
) |>
97142
dplyr::mutate(
@@ -105,6 +150,7 @@ sits_uncertainty_sampling <- function(uncert_cube,
105150
c("longitude", "latitude", "value")
106151
)) |>
107152
tibble::as_tibble()
153+
})
108154
# All the cube's uncertainty images have the same start & end dates.
109155
top_values[["start_date"]] <- .tile_start_date(tile)
110156
top_values[["end_date"]] <- .tile_end_date(tile)
@@ -174,6 +220,10 @@ sits_uncertainty_sampling <- function(uncert_cube,
174220
#' @param min_margin Minimum margin of confidence to select a sample
175221
#' @param sampling_window Window size for collecting points (in pixels).
176222
#' The minimum window size is 10.
223+
#' @param multicores Number of workers for parallel processing
224+
#' (integer, min = 1, max = 2048).
225+
#' @param memsize Maximum overall memory (in GB) to run the
226+
#' function (integer, min = 1, max = 16384).
177227
#'
178228
#' @return
179229
#' A tibble with longitude and latitude in WGS84 with locations
@@ -204,54 +254,92 @@ sits_uncertainty_sampling <- function(uncert_cube,
204254
sits_confidence_sampling <- function(probs_cube,
205255
n = 20L,
206256
min_margin = 0.90,
207-
sampling_window = 10L) {
257+
sampling_window = 10L,
258+
multicores = 1L,
259+
memsize = 1L) {
208260
.check_set_caller("sits_confidence_sampling")
209-
210261
# Pre-conditions
211262
.check_is_probs_cube(probs_cube)
212263
.check_int_parameter(n, min = 20)
213264
.check_num_parameter(min_margin, min = 0.01, max = 1.0)
214265
.check_int_parameter(sampling_window, min = 10)
215-
266+
.check_int_parameter(multicores, min = 1, max = 2048)
267+
.check_int_parameter(memsize, min = 1, max = 16384)
268+
# Get block size
269+
block <- .raster_file_blocksize(.raster_open_rast(.tile_path(probs_cube)))
270+
# Overlapping pixels
271+
overlap <- ceiling(sampling_window / 2) - 1
272+
# Check minimum memory needed to process one block
273+
job_memsize <- .jobs_memsize(
274+
job_size = .block_size(block = block, overlap = overlap),
275+
npaths = sampling_window,
276+
nbytes = 8,
277+
proc_bloat = .conf("processing_bloat_cpu")
278+
)
279+
# Update multicores parameter
280+
multicores <- .jobs_max_multicores(
281+
job_memsize = job_memsize,
282+
memsize = memsize,
283+
multicores = multicores
284+
)
285+
# Update block parameter
286+
block <- .jobs_optimal_block(
287+
job_memsize = job_memsize,
288+
block = block,
289+
image_size = .tile_size(.tile(probs_cube)),
290+
memsize = memsize,
291+
multicores = multicores
292+
)
293+
# Prepare parallel processing
294+
.parallel_start(workers = multicores)
295+
on.exit(.parallel_stop(), add = TRUE)
216296
# get labels
217297
labels <- sits_labels(probs_cube)
218-
219298
# Slide on cube tiles
220299
samples_tb <- slider::slide_dfr(probs_cube, function(tile) {
221-
# Open raster
222-
r_obj <- .raster_open_rast(.tile_path(tile))
223-
224-
# Get samples for each label
225-
purrr::map2_dfr(labels, seq_along(labels), function(lab, i) {
226-
# Get a list of values of high confidence & apply threshold
227-
top_values <- r_obj |>
228-
.raster_get_top_values(
229-
band = i,
230-
n = n,
231-
sampling_window = sampling_window
232-
) |>
233-
dplyr::mutate(
234-
value = .data[["value"]] *
235-
.conf("probs_cube_scale_factor")
236-
) |>
237-
dplyr::filter(
238-
.data[["value"]] >= min_margin
239-
) |>
240-
dplyr::select(dplyr::matches(
241-
c("longitude", "latitude", "value")
242-
)) |>
243-
tibble::as_tibble()
300+
# Create chunks as jobs
301+
chunks <- .tile_chunks_create(
302+
tile = tile,
303+
overlap = overlap,
304+
block = block
305+
)
306+
# Tile path
307+
tile_path <- .tile_path(tile)
308+
# Get a list of values of high confidence
309+
# Process jobs in parallel
310+
.jobs_map_parallel_dfr(chunks, function(chunk) {
311+
# Get samples for each label
312+
purrr::map2_dfr(labels, seq_along(labels), function(lab, i) {
313+
# Get a list of values of high confidence & apply threshold
314+
top_values <- .raster_open_rast(tile_path) |>
315+
.raster_get_top_values(
316+
block = .block(chunk),
317+
band = i,
318+
n = n,
319+
sampling_window = sampling_window
320+
) |>
321+
dplyr::mutate(
322+
value = .data[["value"]] *
323+
.conf("probs_cube_scale_factor")
324+
) |>
325+
dplyr::filter(
326+
.data[["value"]] >= min_margin
327+
) |>
328+
dplyr::select(dplyr::matches(
329+
c("longitude", "latitude", "value")
330+
)) |>
331+
tibble::as_tibble()
244332

245-
# All the cube's uncertainty images have the same start &
246-
# end dates.
247-
top_values[["start_date"]] <- .tile_start_date(tile)
248-
top_values[["end_date"]] <- .tile_end_date(tile)
249-
top_values[["label"]] <- lab
333+
# All the cube's probability images have the same start &
334+
# end dates.
335+
top_values[["start_date"]] <- .tile_start_date(tile)
336+
top_values[["end_date"]] <- .tile_end_date(tile)
337+
top_values[["label"]] <- lab
250338

251-
return(top_values)
339+
return(top_values)
340+
})
252341
})
253342
})
254-
255343
# Slice result samples
256344
result_tb <- samples_tb |>
257345
dplyr::group_by(.data[["label"]]) |>

man/sits_confidence_sampling.Rd

+9-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/sits_uncertainty_sampling.Rd

+9-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-active_learning.R

+9-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ test_that("Suggested samples have low confidence, high entropy", {
3434
uncert_cube,
3535
min_uncert = 0.3,
3636
n = 100,
37-
sampling_window = 10
37+
sampling_window = 10,
38+
multicores = 2,
39+
memsize = 2
3840
))
3941

4042
expect_true(nrow(samples_df) <= 100)
@@ -80,15 +82,19 @@ test_that("Increased samples have high confidence, low entropy", {
8082
probs_cube = probs_cube,
8183
n = 20,
8284
min_margin = 0.5,
83-
sampling_window = 10
85+
sampling_window = 10,
86+
multicores = 2,
87+
memsize = 2
8488
)
8589
)
8690
expect_warning(
8791
sits_confidence_sampling(
8892
probs_cube = probs_cube,
8993
n = 60,
9094
min_margin = 0.5,
91-
sampling_window = 10
95+
sampling_window = 10,
96+
multicores = 2,
97+
memsize = 2
9298
)
9399
)
94100
labels <- sits_labels(probs_cube)

0 commit comments

Comments
 (0)