Skip to content

Commit 9c15982

Browse files
committed
Merge branch 'main' of https://github.com/psychelzh/cpmr
2 parents 22dd9e2 + 686ef9c commit 9c15982

File tree

5 files changed

+154
-15
lines changed

5 files changed

+154
-15
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
* Added `summary()` method to summarize the results of the CPM analysis.
66
* Added `tidy()` method to tidy the results of the CPM analysis.
7+
* Added support for the `na_action` argument in `cpm()` to handle missing values in the input data (#2).
78

89
## Enhancements
910

R/cpm.R

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@
3838
#' selected edges. If `"none"`, no edges are returned. If `"sum"`, the sum of
3939
#' selected edges across folds is returned. If `"all"`, the selected edges for
4040
#' each fold are returned as a 3D array, which is memory-consuming.
41+
#' @param na_action A character string indicating the action when missing values
42+
#' are found in `behav`. If `"fail"`, an error will be thrown. If `"omit"`,
43+
#' cases with missing values will be dropped. If `"exclude"`, such cases will be
44+
#' excluded from the analysis but kept in the output as `NA`. Note `conmat` must
45+
#' not contain any missing values, and `confounds` must not contain missing
46+
#' values for complete cases of `behav`.
4147
#' @return A list with the following components:
4248
#'
4349
#' \item{folds}{The corresponding fold for each observation when used as test
@@ -90,11 +96,13 @@ cpm <- function(conmat, behav, ...,
9096
thresh_level = 0.01,
9197
kfolds = NULL,
9298
bias_correct = TRUE,
93-
return_edges = c("sum", "none", "all")) {
99+
return_edges = c("sum", "none", "all"),
100+
na_action = c("fail", "omit", "exclude")) {
94101
call <- match.call()
95102
thresh_method <- match.arg(thresh_method)
96103
return_edges <- match.arg(return_edges)
97-
# ensure `behav` is a vector
104+
na_action <- match.arg(na_action)
105+
# ensure `behav` is a numeric vector, then check that names and lengths match
98106
behav <- drop(behav)
99107
if (!is.vector(behav) || !is.numeric(behav)) {
100108
stop("Behavior data must be a numeric vector.")
@@ -109,38 +117,58 @@ cpm <- function(conmat, behav, ...,
109117
stop("Case numbers of `confounds` and `behav` must match.")
110118
}
111119
check_names(confounds, behav)
112-
conmat <- regress_counfounds(conmat, confounds)
113-
behav <- regress_counfounds(behav, confounds)
120+
}
121+
# `conmat` cannot contain any missing values
122+
stopifnot("Missing values are not allowed in `conmat`." = !anyNA(conmat))
123+
# handle missing values in `behav`
124+
include_cases <- switch(na_action,
125+
fail = {
126+
stopifnot("Missing values found in `behav`." = !anyNA(behav))
127+
rep(TRUE, length(behav))
128+
},
129+
omit = ,
130+
exclude = !is.na(behav)
131+
)
132+
conmat_use <- conmat[include_cases, , drop = FALSE]
133+
behav_use <- behav[include_cases]
134+
if (!is.null(confounds)) {
135+
confounds_use <- confounds[include_cases, , drop = FALSE]
136+
stopifnot(
137+
"Missing values found for used cases in `confounds`." =
138+
!anyNA(confounds_use)
139+
)
140+
conmat_use <- regress_counfounds(conmat_use, confounds_use)
141+
behav_use <- regress_counfounds(behav_use, confounds_use)
114142
}
115143
# default to leave-one-subject-out
116-
if (is.null(kfolds)) kfolds <- length(behav)
117-
folds <- crossv_kfold(length(behav), kfolds)
144+
if (is.null(kfolds)) kfolds <- length(behav_use)
145+
folds <- crossv_kfold(length(behav_use), kfolds)
118146
# pre-allocation
119147
edges <- switch(return_edges,
120148
all = array(
121-
dim = c(dim(conmat)[2], length(networks), kfolds),
149+
dim = c(dim(conmat_use)[2], length(networks), kfolds),
122150
dimnames = list(NULL, networks, NULL)
123151
),
124152
sum = array(
125153
0,
126-
dim = c(dim(conmat)[2], length(networks)),
154+
dim = c(dim(conmat_use)[2], length(networks)),
127155
dimnames = list(NULL, networks)
128156
)
129157
)
130158
pred <- matrix(
131-
nrow = length(behav),
159+
nrow = length(behav_use),
132160
ncol = length(includes),
133-
dimnames = list(names(behav), includes)
161+
dimnames = list(names(behav_use), includes)
134162
)
135163
for (fold in seq_len(kfolds)) {
136164
rows_train <- folds != fold
137-
conmat_train <- conmat[rows_train, , drop = FALSE]
138-
behav_train <- behav[rows_train]
165+
conmat_train <- conmat_use[rows_train, , drop = FALSE]
166+
behav_train <- behav_use[rows_train]
139167
cur_edges <- select_edges(
140168
conmat_train, behav_train,
141169
thresh_method, thresh_level
142170
)
143-
conmat_test <- conmat[!rows_train, , drop = FALSE]
171+
conmat_test <- conmat_use[!rows_train, , drop = FALSE]
144172
cur_pred <- predict_cpm(
145173
conmat_train, behav_train, conmat_test,
146174
cur_edges, bias_correct
@@ -152,10 +180,21 @@ cpm <- function(conmat, behav, ...,
152180
edges <- edges + cur_edges
153181
}
154182
}
183+
# add back missing values when `na_action` is "exclude"
184+
if (na_action == "exclude") {
185+
behav_use <- behav
186+
pred_all <- matrix(
187+
nrow = length(behav),
188+
ncol = length(includes),
189+
dimnames = list(names(behav), includes)
190+
)
191+
pred_all[include_cases, ] <- pred
192+
pred <- pred_all
193+
}
155194
structure(
156195
list(
157196
folds = folds,
158-
real = behav,
197+
real = behav_use,
159198
pred = pred,
160199
edges = edges,
161200
call = call,

man/cpm.Rd

Lines changed: 9 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/cpm.md

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,3 +621,75 @@
621621
CV folds: 10
622622
Bias correction: TRUE
623623

624+
# `na_action` argument works
625+
626+
{
627+
"type": "double",
628+
"attributes": {},
629+
"value": [0.25688371, -0.24669188, -0.3475426, -0.95161857, -0.04502772, -0.78490447, -1.66794194, -0.38022652, 0.91899661]
630+
}
631+
632+
---
633+
634+
{
635+
"type": "double",
636+
"attributes": {
637+
"dim": {
638+
"type": "integer",
639+
"attributes": {},
640+
"value": [9, 3]
641+
},
642+
"dimnames": {
643+
"type": "list",
644+
"attributes": {},
645+
"value": [
646+
{
647+
"type": "NULL"
648+
},
649+
{
650+
"type": "character",
651+
"attributes": {},
652+
"value": ["both", "pos", "neg"]
653+
}
654+
]
655+
}
656+
},
657+
"value": [-0.43811964, -0.37517269, -0.36256635, -0.28705685, -0.40038071, -0.30789611, -0.19751643, -0.35848086, -0.52088375, -0.43811964, -0.37517269, -0.36256635, -0.28705685, -0.40038071, -0.30789611, -0.19751643, -0.35848086, -0.52088375, -0.43811964, -0.37517269, -0.36256635, -0.28705685, -0.40038071, -0.30789611, -0.19751643, -0.35848086, -0.52088375]
658+
}
659+
660+
---
661+
662+
{
663+
"type": "double",
664+
"attributes": {},
665+
"value": ["NA", 0.25688371, -0.24669188, -0.3475426, -0.95161857, -0.04502772, -0.78490447, -1.66794194, -0.38022652, 0.91899661]
666+
}
667+
668+
---
669+
670+
{
671+
"type": "double",
672+
"attributes": {
673+
"dim": {
674+
"type": "integer",
675+
"attributes": {},
676+
"value": [10, 3]
677+
},
678+
"dimnames": {
679+
"type": "list",
680+
"attributes": {},
681+
"value": [
682+
{
683+
"type": "NULL"
684+
},
685+
{
686+
"type": "character",
687+
"attributes": {},
688+
"value": ["both", "pos", "neg"]
689+
}
690+
]
691+
}
692+
},
693+
"value": ["NA", -0.43811964, -0.37517269, -0.36256635, -0.28705685, -0.40038071, -0.30789611, -0.19751643, -0.35848086, -0.52088375, "NA", -0.43811964, -0.37517269, -0.36256635, -0.28705685, -0.40038071, -0.30789611, -0.19751643, -0.35848086, -0.52088375, "NA", -0.43811964, -0.37517269, -0.36256635, -0.28705685, -0.40038071, -0.30789611, -0.19751643, -0.35848086, -0.52088375]
694+
}
695+

tests/testthat/test-cpm.R

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,22 @@ test_that("Throw informative error if data checking not pass", {
119119
"Case numbers of `confounds` and `behav` must match."
120120
)
121121
})
122+
123+
test_that("`na_action` argument works", {
124+
withr::local_seed(123)
125+
conmat <- matrix(rnorm(100), ncol = 10)
126+
behav <- rnorm(10)
127+
behav[1] <- NA
128+
expect_error(cpm(conmat, behav), "Missing values found in `behav`.")
129+
result <- cpm(conmat, behav, na_action = "omit")
130+
expect_snapshot_value(result$real, style = "json2")
131+
expect_snapshot_value(result$pred, style = "json2")
132+
result <- cpm(conmat, behav, na_action = "exclude")
133+
expect_snapshot_value(result$real, style = "json2")
134+
expect_snapshot_value(result$pred, style = "json2")
135+
conmat[1, 1] <- NA
136+
expect_error(
137+
cpm(conmat, behav),
138+
"Missing values are not allowed in `conmat`."
139+
)
140+
})

0 commit comments

Comments
 (0)