Skip to content

Commit

Permalink
Let na_action handle missing by intersections
Browse files Browse the repository at this point in the history
  • Loading branch information
psychelzh committed Oct 6, 2024
1 parent 9c15982 commit 58113cb
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 131 deletions.
108 changes: 59 additions & 49 deletions R/cpm.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,10 @@
#' selected edges across folds is returned. If `"all"`, the selected edges for
#' each fold is returned, which is a 3D array and memory-consuming.
#' @param na_action A character string indicating the action when missing values
#' are found in `behav`. If `"fail"`, an error will be thrown. If `"omit"`,
#' missing values will be omitted. If `"exclude"`, missing values will be
#' excluded from the analysis and added back to the output. Note `conmat` must
#' not contain any missing values, and `confounds` must not contain missing
#' values for complete cases of `behav`.
#' are found in `behav`. If `"fail"`, an error will be thrown. If `"exclude"`,
#' missing values will be excluded from the analysis but kept in the output.
#' Note complete cases are intersection of `conmat`, `behav` and `confounds`
#' if provided.
#' @return A list with the following components:
#'
#' \item{folds}{The corresponding fold for each observation when used as test
Expand Down Expand Up @@ -97,13 +96,14 @@ cpm <- function(conmat, behav, ...,
kfolds = NULL,
bias_correct = TRUE,
return_edges = c("sum", "none", "all"),
na_action = c("fail", "omit", "exclude")) {
na_action = c("fail", "exclude")) {
call <- match.call()
thresh_method <- match.arg(thresh_method)
return_edges <- match.arg(return_edges)
na_action <- match.arg(na_action)
# ensure `behav` is a vector, name and length match
behav <- drop(behav)

# check input data
behav <- drop(behav) # convert to vector
if (!is.vector(behav) || !is.numeric(behav)) {
stop("Behavior data must be a numeric vector.")
}
Expand All @@ -118,83 +118,92 @@ cpm <- function(conmat, behav, ...,
}
check_names(confounds, behav)
}
# `conmat` cannot contain any missing values
stopifnot("Missing values are not allowed in `conmat`." = !anyNA(conmat))
# handle missing values in `behav`

# handle missing cases
include_cases <- switch(na_action,
fail = {
stopifnot("Missing values found in `behav`." = !anyNA(behav))
rep(TRUE, length(behav))
stopifnot(
"Missing values found in `conmat`" = !anyNA(conmat),
"Missing values found in `behav`" = !anyNA(behav),
"Missing values found in `confounds`" =
is.null(confounds) || !anyNA(confounds)
)
seq_along(behav)
},
omit = ,
exclude = !is.na(behav)
exclude = Reduce(
function(x, y) intersect(x, y),
list(
which(complete.cases(conmat)),
which(complete.cases(behav)),
if (!is.null(confounds)) {
which(complete.cases(confounds))
} else {
seq_along(behav)
}
)
)
)
conmat_use <- conmat[include_cases, , drop = FALSE]
behav_use <- behav[include_cases]

# confounds regression
if (!is.null(confounds)) {
confounds_use <- confounds[include_cases, , drop = FALSE]
stopifnot(
"Missing values found for used cases in `confounds`." =
!anyNA(confounds_use)
conmat[include_cases, ] <- regress_counfounds(
conmat[include_cases, , drop = FALSE],
confounds[include_cases, , drop = FALSE]
)
behav[include_cases] <- regress_counfounds(
behav[include_cases],
confounds[include_cases, , drop = FALSE]
)
conmat_use <- regress_counfounds(conmat_use, confounds_use)
behav_use <- regress_counfounds(behav_use, confounds_use)
}
# default to leave-one-subject-out
if (is.null(kfolds)) kfolds <- length(behav_use)
folds <- crossv_kfold(length(behav_use), kfolds)

# prepare for cross-validation
if (is.null(kfolds)) kfolds <- length(include_cases) # default to LOOCV
folds <- crossv_kfold(include_cases, kfolds)

# pre-allocation
edges <- switch(return_edges,
all = array(
dim = c(dim(conmat_use)[2], length(networks), kfolds),
dim = c(dim(conmat)[2], length(networks), kfolds),
dimnames = list(NULL, networks, NULL)
),
sum = array(
0,
dim = c(dim(conmat_use)[2], length(networks)),
dim = c(dim(conmat)[2], length(networks)),
dimnames = list(NULL, networks)
)
)
pred <- matrix(
nrow = length(behav_use),
nrow = length(behav),
ncol = length(includes),
dimnames = list(names(behav_use), includes)
dimnames = list(names(behav), includes)
)

# process each fold of CPM
for (fold in seq_len(kfolds)) {
rows_train <- folds != fold
conmat_train <- conmat_use[rows_train, , drop = FALSE]
behav_train <- behav_use[rows_train]
rows_test <- folds[[fold]]
rows_train <- setdiff(include_cases, rows_test)
conmat_train <- conmat[rows_train, , drop = FALSE]
behav_train <- behav[rows_train]
cur_edges <- select_edges(
conmat_train, behav_train,
thresh_method, thresh_level
)
conmat_test <- conmat_use[!rows_train, , drop = FALSE]
conmat_test <- conmat[rows_test, , drop = FALSE]
cur_pred <- predict_cpm(
conmat_train, behav_train, conmat_test,
cur_edges, bias_correct
)
pred[!rows_train, ] <- cur_pred
pred[rows_test, ] <- cur_pred
if (return_edges == "all") {
edges[, , fold] <- cur_edges
} else if (return_edges == "sum") {
edges <- edges + cur_edges
}
}
# add back missing values when `na_action` is "exclude"
if (na_action == "exclude") {
behav_use <- behav
pred_all <- matrix(
nrow = length(behav),
ncol = length(includes),
dimnames = list(names(behav), includes)
)
pred_all[include_cases, ] <- pred
pred <- pred_all
}
structure(
list(
folds = folds,
real = behav_use,
real = behav,
pred = pred,
edges = edges,
call = call,
Expand All @@ -216,6 +225,7 @@ print.cpm <- function(x, ...) {
cat(" Call: ")
print(x$call)
cat(sprintf(" Number of observations: %d\n", length(x$real)))
cat(sprintf(" Complete cases: %d\n", sum(complete.cases(x$pred))))
if (!is.null(x$edges)) {
cat(sprintf(" Number of edges: %d\n", dim(x$edges)[1]))
} else {
Expand Down Expand Up @@ -325,8 +335,8 @@ critical_r <- function(n, alpha) {
sqrt((ct^2) / ((ct^2) + df))
}

crossv_kfold <- function(n, k) {
sample(cut(seq_len(n), breaks = k, labels = FALSE))
crossv_kfold <- function(x, k) {
split(sample(x), cut(seq_along(x), breaks = k, labels = FALSE))
}

fscale <- function(x, center, scale) {
Expand Down
11 changes: 5 additions & 6 deletions man/cpm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 58113cb

Please sign in to comment.