Skip to content

Commit 58113cb

Browse files
committed
Let na_action handle missing by intersections
1 parent 9c15982 commit 58113cb

File tree

4 files changed

+130
-131
lines changed

4 files changed

+130
-131
lines changed

R/cpm.R

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,10 @@
3939
#' selected edges across folds is returned. If `"all"`, the selected edges for
4040
#' each fold is returned, which is a 3D array and memory-consuming.
4141
#' @param na_action A character string indicating the action when missing values
42-
#' are found in `behav`. If `"fail"`, an error will be thrown. If `"omit"`,
43-
#' missing values will be omitted. If `"exclude"`, missing values will be
44-
#' excluded from the analysis and added back to the output. Note `conmat` must
45-
#' not contain any missing values, and `confounds` must not contain missing
46-
#' values for complete cases of `behav`.
42+
#' are found in `behav`. If `"fail"`, an error will be thrown. If `"exclude"`,
43+
#' missing values will be excluded from the analysis but kept in the output.
44+
#' Note complete cases are intersection of `conmat`, `behav` and `confounds`
45+
#' if provided.
4746
#' @return A list with the following components:
4847
#'
4948
#' \item{folds}{The corresponding fold for each observation when used as test
@@ -97,13 +96,14 @@ cpm <- function(conmat, behav, ...,
9796
kfolds = NULL,
9897
bias_correct = TRUE,
9998
return_edges = c("sum", "none", "all"),
100-
na_action = c("fail", "omit", "exclude")) {
99+
na_action = c("fail", "exclude")) {
101100
call <- match.call()
102101
thresh_method <- match.arg(thresh_method)
103102
return_edges <- match.arg(return_edges)
104103
na_action <- match.arg(na_action)
105-
# ensure `behav` is a vector, name and length match
106-
behav <- drop(behav)
104+
105+
# check input data
106+
behav <- drop(behav) # convert to vector
107107
if (!is.vector(behav) || !is.numeric(behav)) {
108108
stop("Behavior data must be a numeric vector.")
109109
}
@@ -118,83 +118,92 @@ cpm <- function(conmat, behav, ...,
118118
}
119119
check_names(confounds, behav)
120120
}
121-
# `conmat` cannot contain any missing values
122-
stopifnot("Missing values are not allowed in `conmat`." = !anyNA(conmat))
123-
# handle missing values in `behav`
121+
122+
# handle missing cases
124123
include_cases <- switch(na_action,
125124
fail = {
126-
stopifnot("Missing values found in `behav`." = !anyNA(behav))
127-
rep(TRUE, length(behav))
125+
stopifnot(
126+
"Missing values found in `conmat`" = !anyNA(conmat),
127+
"Missing values found in `behav`" = !anyNA(behav),
128+
"Missing values found in `confounds`" =
129+
is.null(confounds) || !anyNA(confounds)
130+
)
131+
seq_along(behav)
128132
},
129-
omit = ,
130-
exclude = !is.na(behav)
133+
exclude = Reduce(
134+
function(x, y) intersect(x, y),
135+
list(
136+
which(complete.cases(conmat)),
137+
which(complete.cases(behav)),
138+
if (!is.null(confounds)) {
139+
which(complete.cases(confounds))
140+
} else {
141+
seq_along(behav)
142+
}
143+
)
144+
)
131145
)
132-
conmat_use <- conmat[include_cases, , drop = FALSE]
133-
behav_use <- behav[include_cases]
146+
147+
# confounds regression
134148
if (!is.null(confounds)) {
135-
confounds_use <- confounds[include_cases, , drop = FALSE]
136-
stopifnot(
137-
"Missing values found for used cases in `confounds`." =
138-
!anyNA(confounds_use)
149+
conmat[include_cases, ] <- regress_counfounds(
150+
conmat[include_cases, , drop = FALSE],
151+
confounds[include_cases, , drop = FALSE]
152+
)
153+
behav[include_cases] <- regress_counfounds(
154+
behav[include_cases],
155+
confounds[include_cases, , drop = FALSE]
139156
)
140-
conmat_use <- regress_counfounds(conmat_use, confounds_use)
141-
behav_use <- regress_counfounds(behav_use, confounds_use)
142157
}
143-
# default to leave-one-subject-out
144-
if (is.null(kfolds)) kfolds <- length(behav_use)
145-
folds <- crossv_kfold(length(behav_use), kfolds)
158+
159+
# prepare for cross-validation
160+
if (is.null(kfolds)) kfolds <- length(include_cases) # default to LOOCV
161+
folds <- crossv_kfold(include_cases, kfolds)
162+
146163
# pre-allocation
147164
edges <- switch(return_edges,
148165
all = array(
149-
dim = c(dim(conmat_use)[2], length(networks), kfolds),
166+
dim = c(dim(conmat)[2], length(networks), kfolds),
150167
dimnames = list(NULL, networks, NULL)
151168
),
152169
sum = array(
153170
0,
154-
dim = c(dim(conmat_use)[2], length(networks)),
171+
dim = c(dim(conmat)[2], length(networks)),
155172
dimnames = list(NULL, networks)
156173
)
157174
)
158175
pred <- matrix(
159-
nrow = length(behav_use),
176+
nrow = length(behav),
160177
ncol = length(includes),
161-
dimnames = list(names(behav_use), includes)
178+
dimnames = list(names(behav), includes)
162179
)
180+
181+
# process each fold of CPM
163182
for (fold in seq_len(kfolds)) {
164-
rows_train <- folds != fold
165-
conmat_train <- conmat_use[rows_train, , drop = FALSE]
166-
behav_train <- behav_use[rows_train]
183+
rows_test <- folds[[fold]]
184+
rows_train <- setdiff(include_cases, rows_test)
185+
conmat_train <- conmat[rows_train, , drop = FALSE]
186+
behav_train <- behav[rows_train]
167187
cur_edges <- select_edges(
168188
conmat_train, behav_train,
169189
thresh_method, thresh_level
170190
)
171-
conmat_test <- conmat_use[!rows_train, , drop = FALSE]
191+
conmat_test <- conmat[rows_test, , drop = FALSE]
172192
cur_pred <- predict_cpm(
173193
conmat_train, behav_train, conmat_test,
174194
cur_edges, bias_correct
175195
)
176-
pred[!rows_train, ] <- cur_pred
196+
pred[rows_test, ] <- cur_pred
177197
if (return_edges == "all") {
178198
edges[, , fold] <- cur_edges
179199
} else if (return_edges == "sum") {
180200
edges <- edges + cur_edges
181201
}
182202
}
183-
# add back missing values when `na_action` is "exclude"
184-
if (na_action == "exclude") {
185-
behav_use <- behav
186-
pred_all <- matrix(
187-
nrow = length(behav),
188-
ncol = length(includes),
189-
dimnames = list(names(behav), includes)
190-
)
191-
pred_all[include_cases, ] <- pred
192-
pred <- pred_all
193-
}
194203
structure(
195204
list(
196205
folds = folds,
197-
real = behav_use,
206+
real = behav,
198207
pred = pred,
199208
edges = edges,
200209
call = call,
@@ -216,6 +225,7 @@ print.cpm <- function(x, ...) {
216225
cat(" Call: ")
217226
print(x$call)
218227
cat(sprintf(" Number of observations: %d\n", length(x$real)))
228+
cat(sprintf(" Complete cases: %d\n", sum(complete.cases(x$pred))))
219229
if (!is.null(x$edges)) {
220230
cat(sprintf(" Number of edges: %d\n", dim(x$edges)[1]))
221231
} else {
@@ -325,8 +335,8 @@ critical_r <- function(n, alpha) {
325335
sqrt((ct^2) / ((ct^2) + df))
326336
}
327337

328-
crossv_kfold <- function(n, k) {
329-
sample(cut(seq_len(n), breaks = k, labels = FALSE))
338+
crossv_kfold <- function(x, k) {
339+
split(sample(x), cut(seq_along(x), breaks = k, labels = FALSE))
330340
}
331341

332342
fscale <- function(x, center, scale) {

man/cpm.Rd

Lines changed: 5 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)