Skip to content

Commit ec7b5c4

Browse files
committed
replace probability fractions with NA in classification results
1 parent a5921ff commit ec7b5c4

File tree

2 files changed

+76
-5
lines changed

2 files changed

+76
-5
lines changed

R/api_classify.R

+2-5
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,6 @@
106106
# Should bbox of resulting tile be updated?
107107
update_bbox <- nrow(chunks) != nchunks
108108
}
109-
# Compute fractions probability
110-
probs_fractions <- 1 / length(.ml_labels(ml_model))
111109
# Process jobs in parallel
112110
block_files <- .jobs_map_parallel_chr(chunks, function(chunk) {
113111
# Job block
@@ -171,10 +169,9 @@
171169
scale <- .scale(band_conf)
172170
if (.has(scale) && scale != 1) {
173171
values <- values / scale
174-
probs_fractions <- probs_fractions / scale
175172
}
176-
# Mask NA pixels with same probabilities for all classes
177-
values[na_mask, ] <- probs_fractions
173+
# Put NA back in the result
174+
values[na_mask, ] <- NA
178175
# Log
179176
.debug_log(
180177
event = "start_block_data_save",

tests/testthat/test-classification.R

+74
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,77 @@ test_that("Classify error bands 1", {
5656
)
5757
)
5858
})
59+
60+
test_that("Classify with NA values", {
61+
# load cube
62+
data_dir <- system.file("extdata/raster/mod13q1", package = "sits")
63+
raster_cube <- sits_cube(
64+
source = "BDC",
65+
collection = "MOD13Q1-6.1",
66+
data_dir = data_dir,
67+
tiles = "012010",
68+
bands = "NDVI",
69+
start_date = "2013-09-14",
70+
end_date = "2014-08-29",
71+
multicores = 2,
72+
progress = FALSE
73+
)
74+
# preparation - create directory to save NA
75+
data_dir <- paste0(tempdir(), "/na-cube")
76+
dir.create(data_dir, recursive = TRUE, showWarnings = FALSE)
77+
# preparation - insert NA in cube
78+
raster_cube <- sits_apply(
79+
data = raster_cube,
80+
NDVI_NA = ifelse(NDVI > 0.5, NA, NDVI),
81+
output_dir = data_dir
82+
)
83+
raster_cube <- sits_select(raster_cube, bands = "NDVI_NA")
84+
.fi(raster_cube) <- .fi(raster_cube) |>
85+
dplyr::mutate(band = "NDVI")
86+
# preparation - create a random forest model
87+
rfor_model <- sits_train(samples_modis_ndvi, sits_rfor(num_trees = 40))
88+
# test classification with NA
89+
class_map <- sits_classify(
90+
data = raster_cube,
91+
ml_model = rfor_model,
92+
output_dir = tempdir(),
93+
progress = FALSE
94+
)
95+
class_map_rst <- terra::rast(class_map[["file_info"]][[1]][["path"]])
96+
expect_true(anyNA(class_map_rst[]))
97+
})
98+
99+
test_that("Classify with exclusion mask", {
100+
# load cube
101+
data_dir <- system.file("extdata/raster/mod13q1", package = "sits")
102+
raster_cube <- sits_cube(
103+
source = "BDC",
104+
collection = "MOD13Q1-6.1",
105+
data_dir = data_dir,
106+
tiles = "012010",
107+
bands = "NDVI",
108+
start_date = "2013-09-14",
109+
end_date = "2014-08-29",
110+
multicores = 2,
111+
progress = FALSE
112+
)
113+
# preparation - create a random forest model
114+
rfor_model <- sits_train(samples_modis_ndvi, sits_rfor(num_trees = 40))
115+
# test classification with NA
116+
class_map <- suppressWarnings(
117+
sits_classify(
118+
data = raster_cube,
119+
ml_model = rfor_model,
120+
output_dir = tempdir(),
121+
exclusion_mask = c(
122+
xmin = -55.63478,
123+
ymin = -11.63328,
124+
xmax = -55.54080,
125+
ymax = -11.56978
126+
),
127+
progress = FALSE
128+
)
129+
)
130+
class_map_rst <- terra::rast(class_map[["file_info"]][[1]][["path"]])
131+
expect_true(anyNA(class_map_rst[]))
132+
})

0 commit comments

Comments
 (0)