39
39
# ' selected edges across folds is returned. If `"all"`, the selected edges for
40
40
# ' each fold is returned, which is a 3D array and memory-consuming.
41
41
# ' @param na_action A character string indicating the action when missing values
42
- # ' are found in `behav`. If `"fail"`, an error will be thrown. If `"omit"`,
43
- # ' missing values will be omitted. If `"exclude"`, missing values will be
44
- # ' excluded from the analysis and added back to the output. Note `conmat` must
45
- # ' not contain any missing values, and `confounds` must not contain missing
46
- # ' values for complete cases of `behav`.
42
+ # ' are found in `behav`. If `"fail"`, an error will be thrown. If `"exclude"`,
43
+ # ' missing values will be excluded from the analysis but kept in the output.
44
+ # ' Note complete cases are intersection of `conmat`, `behav` and `confounds`
45
+ # ' if provided.
47
46
# ' @return A list with the following components:
48
47
# '
49
48
# ' \item{folds}{The corresponding fold for each observation when used as test
@@ -97,13 +96,14 @@ cpm <- function(conmat, behav, ...,
97
96
kfolds = NULL ,
98
97
bias_correct = TRUE ,
99
98
return_edges = c(" sum" , " none" , " all" ),
100
- na_action = c(" fail" , " omit " , " exclude" )) {
99
+ na_action = c(" fail" , " exclude" )) {
101
100
call <- match.call()
102
101
thresh_method <- match.arg(thresh_method )
103
102
return_edges <- match.arg(return_edges )
104
103
na_action <- match.arg(na_action )
105
- # ensure `behav` is a vector, name and length match
106
- behav <- drop(behav )
104
+
105
+ # check input data
106
+ behav <- drop(behav ) # convert to vector
107
107
if (! is.vector(behav ) || ! is.numeric(behav )) {
108
108
stop(" Behavior data must be a numeric vector." )
109
109
}
@@ -118,83 +118,92 @@ cpm <- function(conmat, behav, ...,
118
118
}
119
119
check_names(confounds , behav )
120
120
}
121
- # `conmat` cannot contain any missing values
122
- stopifnot(" Missing values are not allowed in `conmat`." = ! anyNA(conmat ))
123
- # handle missing values in `behav`
121
+
122
+ # handle missing cases
124
123
include_cases <- switch (na_action ,
125
124
fail = {
126
- stopifnot(" Missing values found in `behav`." = ! anyNA(behav ))
127
- rep(TRUE , length(behav ))
125
+ stopifnot(
126
+ " Missing values found in `conmat`" = ! anyNA(conmat ),
127
+ " Missing values found in `behav`" = ! anyNA(behav ),
128
+ " Missing values found in `confounds`" =
129
+ is.null(confounds ) || ! anyNA(confounds )
130
+ )
131
+ seq_along(behav )
128
132
},
129
- omit = ,
130
- exclude = ! is.na(behav )
133
+ exclude = Reduce(
134
+ function (x , y ) intersect(x , y ),
135
+ list (
136
+ which(complete.cases(conmat )),
137
+ which(complete.cases(behav )),
138
+ if (! is.null(confounds )) {
139
+ which(complete.cases(confounds ))
140
+ } else {
141
+ seq_along(behav )
142
+ }
143
+ )
144
+ )
131
145
)
132
- conmat_use <- conmat [ include_cases , , drop = FALSE ]
133
- behav_use <- behav [ include_cases ]
146
+
147
+ # confounds regression
134
148
if (! is.null(confounds )) {
135
- confounds_use <- confounds [include_cases , , drop = FALSE ]
136
- stopifnot(
137
- " Missing values found for used cases in `confounds`." =
138
- ! anyNA(confounds_use )
149
+ conmat [include_cases , ] <- regress_counfounds(
150
+ conmat [include_cases , , drop = FALSE ],
151
+ confounds [include_cases , , drop = FALSE ]
152
+ )
153
+ behav [include_cases ] <- regress_counfounds(
154
+ behav [include_cases ],
155
+ confounds [include_cases , , drop = FALSE ]
139
156
)
140
- conmat_use <- regress_counfounds(conmat_use , confounds_use )
141
- behav_use <- regress_counfounds(behav_use , confounds_use )
142
157
}
143
- # default to leave-one-subject-out
144
- if (is.null(kfolds )) kfolds <- length(behav_use )
145
- folds <- crossv_kfold(length(behav_use ), kfolds )
158
+
159
+ # prepare for cross-validation
160
+ if (is.null(kfolds )) kfolds <- length(include_cases ) # default to LOOCV
161
+ folds <- crossv_kfold(include_cases , kfolds )
162
+
146
163
# pre-allocation
147
164
edges <- switch (return_edges ,
148
165
all = array (
149
- dim = c(dim(conmat_use )[2 ], length(networks ), kfolds ),
166
+ dim = c(dim(conmat )[2 ], length(networks ), kfolds ),
150
167
dimnames = list (NULL , networks , NULL )
151
168
),
152
169
sum = array (
153
170
0 ,
154
- dim = c(dim(conmat_use )[2 ], length(networks )),
171
+ dim = c(dim(conmat )[2 ], length(networks )),
155
172
dimnames = list (NULL , networks )
156
173
)
157
174
)
158
175
pred <- matrix (
159
- nrow = length(behav_use ),
176
+ nrow = length(behav ),
160
177
ncol = length(includes ),
161
- dimnames = list (names(behav_use ), includes )
178
+ dimnames = list (names(behav ), includes )
162
179
)
180
+
181
+ # process each fold of CPM
163
182
for (fold in seq_len(kfolds )) {
164
- rows_train <- folds != fold
165
- conmat_train <- conmat_use [rows_train , , drop = FALSE ]
166
- behav_train <- behav_use [rows_train ]
183
+ rows_test <- folds [[fold ]]
184
+ rows_train <- setdiff(include_cases , rows_test )
185
+ conmat_train <- conmat [rows_train , , drop = FALSE ]
186
+ behav_train <- behav [rows_train ]
167
187
cur_edges <- select_edges(
168
188
conmat_train , behav_train ,
169
189
thresh_method , thresh_level
170
190
)
171
- conmat_test <- conmat_use [ ! rows_train , , drop = FALSE ]
191
+ conmat_test <- conmat [ rows_test , , drop = FALSE ]
172
192
cur_pred <- predict_cpm(
173
193
conmat_train , behav_train , conmat_test ,
174
194
cur_edges , bias_correct
175
195
)
176
- pred [! rows_train , ] <- cur_pred
196
+ pred [rows_test , ] <- cur_pred
177
197
if (return_edges == " all" ) {
178
198
edges [, , fold ] <- cur_edges
179
199
} else if (return_edges == " sum" ) {
180
200
edges <- edges + cur_edges
181
201
}
182
202
}
183
- # add back missing values when `na_action` is "exclude"
184
- if (na_action == " exclude" ) {
185
- behav_use <- behav
186
- pred_all <- matrix (
187
- nrow = length(behav ),
188
- ncol = length(includes ),
189
- dimnames = list (names(behav ), includes )
190
- )
191
- pred_all [include_cases , ] <- pred
192
- pred <- pred_all
193
- }
194
203
structure(
195
204
list (
196
205
folds = folds ,
197
- real = behav_use ,
206
+ real = behav ,
198
207
pred = pred ,
199
208
edges = edges ,
200
209
call = call ,
@@ -216,6 +225,7 @@ print.cpm <- function(x, ...) {
216
225
cat(" Call: " )
217
226
print(x $ call )
218
227
cat(sprintf(" Number of observations: %d\n " , length(x $ real )))
228
+ cat(sprintf(" Complete cases: %d\n " , sum(complete.cases(x $ pred ))))
219
229
if (! is.null(x $ edges )) {
220
230
cat(sprintf(" Number of edges: %d\n " , dim(x $ edges )[1 ]))
221
231
} else {
@@ -325,8 +335,8 @@ critical_r <- function(n, alpha) {
325
335
sqrt((ct ^ 2 ) / ((ct ^ 2 ) + df ))
326
336
}
327
337
328
- crossv_kfold <- function (n , k ) {
329
- sample(cut(seq_len( n ), breaks = k , labels = FALSE ))
338
+ crossv_kfold <- function (x , k ) {
339
+ split( sample(x ), cut(seq_along( x ), breaks = k , labels = FALSE ))
330
340
}
331
341
332
342
fscale <- function (x , center , scale ) {
0 commit comments