Skip to content

Commit 3b06afa

Browse files
committed
version change
1 parent e35fcdc commit 3b06afa

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FastaLoader"
22
uuid = "139838d8-4077-4d8a-94e1-e6dd554a184c"
33
authors = ["Shane Kuei Hsien Chu ([email protected])"]
4-
version = "0.1.6"
4+
version = "0.1.7"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/fasta_w_splits.jl

+12-2
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,21 @@ end
147147

148148
function get_test_set_for_flux(fws::FASTA_DNA_w_splits; gpu=true)
149149
test_set_ind = get_test_set_ind(fws.mcs);
150-
return fws.data_matrix_gpu[:,:,test_set_ind], fws.label_indicators_gpu[test_set_ind,:]
150+
if gpu
151+
return fws.data_matrix_gpu[:,:,test_set_ind], fws.label_indicators_gpu[test_set_ind,:]
152+
else
153+
return fws.data_matrix[:,:,test_set_ind], fws.label_indicators[test_set_ind,:]
154+
end
151155
end
152156

153157
function get_train_fold_for_flux(fws::FASTA_DNA_w_splits, fold::Int; gpu=true)
154158
train_set_ind, valid_set_ind = get_train_fold_ind(fws.mcs, fold)
155-
return fws.data_matrix_gpu[:,:,train_set_ind], fws.label_indicators_gpu[valid_set_ind,:]
159+
if gpu
160+
return fws.data_matrix_gpu[:,:,train_set_ind], fws.label_indicators_gpu[train_set_ind,:]
161+
fws.data_matrix_gpu[:,:,valid_set_ind], fws.label_indicators_gpu[valid_set_ind,:]
162+
else
163+
return fws.data_matrix[:,:,train_set_ind], fws.label_indicators[train_set_ind,:]
164+
fws.data_matrix[:,:,valid_set_ind], fws.label_indicators[valid_set_ind,:]
165+
end
156166
end
157167

0 commit comments

Comments
 (0)