Skip to content

Commit

Permalink
Catch some edge-case in heuristic_in_bin()
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Jan 1, 2024
1 parent a5f2fac commit 27a43d4
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 26 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## `sv_dependence()`: Control over automatic color feature selection

### How is the color feature selected anyway?
### How is the color feature selected, anyway?

If no SHAP interaction values are available, by default, the color feature `v'` is selected by the heuristic `potential_interaction()`, which works as follows:

Expand Down
42 changes: 20 additions & 22 deletions R/potential_interactions.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,10 @@ heuristic <- function(color, s, bins, color_num, scale, adjusted) {
if (isTRUE(color_num)) {
color <- .as_numeric(color)
}
color <- split(color, bins)
s <- split(s, bins)
M <- mapply(
heuristic_in_bin,
color = color,
s = s,
color = split(color, bins),
s = split(s, bins),
MoreArgs = list(scale = scale, adjusted = adjusted)
)
stats::weighted.mean(M[1L, ], M[2L, ], na.rm = TRUE)
Expand All @@ -112,24 +110,24 @@ heuristic <- function(color, s, bins, color_num, scale, adjusted) {
#' @returns
#' A (1x2) matrix with heuristic and number of observations.
# Computes the interaction heuristic within a single bin: an R-squared-type
# statistic of the linear model `s ~ color`, returned together with the number
# of observations as a (1x2) matrix with columns `stat` and `n`.
#
# NOTE: this listing is a rendered diff without +/- markers. The
# suppressWarnings(tryCatch(...)) statement directly below is the previous
# implementation (removed by this commit); the statements following it are
# its replacement.
heuristic_in_bin <- function(color, s, scale = FALSE, adjusted = FALSE) {
# Previous implementation: fit the model and map *any* error to
# cbind(stat = NA, n = 0), silencing all warnings.
suppressWarnings(
tryCatch(
{
z <- stats::lm(s ~ color)
r <- z$residuals
n <- length(r)
var_y <- stats::var(z$fitted.values + r)
denom <- if (adjusted) z$df.residual else n - 1
var_r <- sum(r^2) / denom
stat <- 1 - var_r / var_y
if (scale) {
stat <- stat * var_y
}
cbind(stat = stat, n = n)
},
error = function(e) return(cbind(stat = NA, n = 0))
)
)
# New implementation: keep only observations with non-missing color.
ok <- !is.na(color)
color <- color[ok]
s <- s[ok]
n <- length(s)
var_s <- stats::var(s)
# Degenerate bins yield NA but still report their size n: fewer than two
# observations, a (numerically) constant s, or a single distinct color value.
# `||` short-circuits, so var_s is only compared when n >= 2.
# NOTE(review): missing values in `s` are not filtered here; with an NA in `s`
# var_s would be NA and this condition could evaluate to NA -- confirm that
# callers always pass NA-free `s`.
if (n < 2L || var_s < .Machine$double.eps || length(unique(color)) < 2L) {
return(cbind(stat = NA, n = n))
}
z <- stats::lm(s ~ color)
# Residual variance: divide by the residual degrees of freedom if `adjusted`,
# otherwise by n - 1.
var_r <- sum(z$residuals^2) / (if (adjusted) z$df.residual else n - 1)
stat <- 1 - var_r / var_s
if (scale) {
# Multiply the statistic by the variance of s.
stat <- stat * var_s
}
# Non-finite results (e.g., 0/0 when there are no residual degrees of
# freedom) are reported as NA.
if (!is.finite(stat)) {
stat <- NA
}
cbind(stat = stat, n = n)
}

# Like as.numeric(), but can deal with factor variables
Expand Down
4 changes: 2 additions & 2 deletions R/sv_dependence.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528
scale = ih_scale,
adjusted = ih_adjusted
)
# 'scores' can be NULL, or a numeric vector like c(0.1, 0, -0.01, NaN, NA)
# Thus, let's take the first positive one (or none)
# 'scores' can be NULL, or a sorted vector like c(0.1, 0, -0.01, NA)
# Thus, let's take the first positive one (or NULL)
scores <- scores[!is.na(scores) & scores > 0] # NULL stays NULL
color_var <- if (length(scores) >= 1L) names(scores)[1L]
}
Expand Down
99 changes: 98 additions & 1 deletion tests/testthat/test-potential_interactions.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,104 @@ test_that("heuristic_in_bin() returns R-squared", {
})

test_that("Failing heuristic_in_bin() returns NA", {
# Rendered diff: the first expectation is the line removed by this commit,
# the second one is its replacement.
expect_equal(heuristic_in_bin(0, 1:2), cbind(stat = NA, n = 0))
# All color values are missing, so no observation is kept: stat is NA, n is 0.
expect_equal(heuristic_in_bin(c(NA, NA), 1:2), cbind(stat = NA, n = 0))
})

# A constant response must give stat = NA while still reporting the bin size,
# for every combination of `scale` and `adjusted`.
test_that("heuristic_in_bin() returns NA for constant response", {
  expect_equal(
    heuristic_in_bin(color = 1:3, s = c(1, 1, 1)),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(color = 1:3, s = c(1, 1, 1), scale = TRUE),
    cbind(stat = NA, n = 3L)
  )
  # Argument spelled out in full: `adjust` only worked via partial matching
  expect_equal(
    heuristic_in_bin(color = 1:3, s = c(1, 1, 1), adjusted = TRUE),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(color = 1:3, s = c(1, 1, 1), adjusted = TRUE, scale = TRUE),
    cbind(stat = NA, n = 3L)
  )
})

# A constant color feature must give stat = NA while still reporting the bin
# size, for every combination of `scale` and `adjusted`.
test_that("heuristic_in_bin() returns NA for constant color", {
  expect_equal(
    heuristic_in_bin(s = 1:3, color = c(1, 1, 1)),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(s = 1:3, color = c(1, 1, 1), scale = TRUE),
    cbind(stat = NA, n = 3L)
  )
  # Argument spelled out in full: `adjust` only worked via partial matching
  expect_equal(
    heuristic_in_bin(s = 1:3, color = c(1, 1, 1), adjusted = TRUE),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(s = 1:3, color = c(1, 1, 1), adjusted = TRUE, scale = TRUE),
    cbind(stat = NA, n = 3L)
  )
})

# Both response and color constant: the expected statistic is NA (not 0),
# with the bin size still reported, for every combination of `scale` and
# `adjusted`.
test_that("heuristic_in_bin() returns NA if response and color are constant", {
  z <- c(1, 1)
  expect_equal(
    heuristic_in_bin(color = z, s = z),
    cbind(stat = NA, n = 2L)
  )
  expect_equal(
    heuristic_in_bin(color = z, s = z, scale = TRUE),
    cbind(stat = NA, n = 2L)
  )
  # Argument spelled out in full: `adjust` only worked via partial matching
  expect_equal(
    heuristic_in_bin(color = z, s = z, adjusted = TRUE),
    cbind(stat = NA, n = 2L)
  )
  expect_equal(
    heuristic_in_bin(color = z, s = z, adjusted = TRUE, scale = TRUE),
    cbind(stat = NA, n = 2L)
  )
})

# A bin with a single observation must give stat = NA and n = 1, for every
# combination of `scale` and `adjusted`.
test_that("heuristic_in_bin() returns NA for single obs", {
  expect_equal(
    heuristic_in_bin(color = 2, s = 2),
    cbind(stat = NA, n = 1L)
  )
  expect_equal(
    heuristic_in_bin(color = 2, s = 2, scale = TRUE),
    cbind(stat = NA, n = 1L)
  )
  # Argument spelled out in full: `adjust` only worked via partial matching
  expect_equal(
    heuristic_in_bin(color = 2, s = 2, adjusted = TRUE),
    cbind(stat = NA, n = 1L)
  )
  expect_equal(
    heuristic_in_bin(color = 2, s = 2, adjusted = TRUE, scale = TRUE),
    cbind(stat = NA, n = 1L)
  )
})

# Factor color with as many levels as observations: the unadjusted statistic
# is expected to be 1 (or 4 when scaled with s = 2, 4, 6), while the adjusted
# variants are expected to be NA.
test_that("heuristic_in_bin() handles factor color with one obs per level", {
  cc <- factor(LETTERS[1:3])
  expect_equal(
    heuristic_in_bin(color = cc, s = 1:3),
    cbind(stat = 1, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(color = cc, s = 2 * (1:3), scale = TRUE),
    cbind(stat = 4, n = 3L)
  )
  # Argument spelled out in full: `adjust` only worked via partial matching
  expect_equal(
    heuristic_in_bin(color = cc, s = 1:3, adjusted = TRUE),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(color = cc, s = 2 * (1:3), adjusted = TRUE, scale = TRUE),
    cbind(stat = NA, n = 3L)
  )
})

test_that("heuristic() returns average R-squared", {
Expand Down

0 comments on commit 27a43d4

Please sign in to comment.