@@ -147,11 +147,21 @@ end
147
147
148
148
function get_test_set_for_flux (fws:: FASTA_DNA_w_splits ; gpu= true )
149
149
test_set_ind = get_test_set_ind (fws. mcs);
150
- return fws. data_matrix_gpu[:,:,test_set_ind], fws. label_indicators_gpu[test_set_ind,:]
150
+ if gpu
151
+ return fws. data_matrix_gpu[:,:,test_set_ind], fws. label_indicators_gpu[test_set_ind,:]
152
+ else
153
+ return fws. data_matrix[:,:,test_set_ind], fws. label_indicators[test_set_ind,:]
154
+ end
151
155
end
152
156
153
157
function get_train_fold_for_flux (fws:: FASTA_DNA_w_splits , fold:: Int ; gpu= true )
154
158
train_set_ind, valid_set_ind = get_train_fold_ind (fws. mcs, fold)
155
- return fws. data_matrix_gpu[:,:,train_set_ind], fws. label_indicators_gpu[valid_set_ind,:]
159
+ if gpu
160
+ return fws. data_matrix_gpu[:,:,train_set_ind], fws. label_indicators_gpu[train_set_ind,:]
161
+ fws. data_matrix_gpu[:,:,valid_set_ind], fws. label_indicators_gpu[valid_set_ind,:]
162
+ else
163
+ return fws. data_matrix[:,:,train_set_ind], fws. label_indicators[train_set_ind,:]
164
+ fws. data_matrix[:,:,valid_set_ind], fws. label_indicators[valid_set_ind,:]
165
+ end
156
166
end
157
167
0 commit comments