@@ -74,13 +74,15 @@ function read_ryan_fasta(all_labels::Vector{String};
74
74
class_indices = [findall (@view class_indicators[:,i]) for i = 1 : size (class_indicators,2 )]
75
75
shuffles_class_indices = shuffle_this ? shuffle .(class_indices) : class_indices;
76
76
# note that class_indicators is not shuffled
77
- return shuffles_class_indices, class_indicators
77
+ # TODO : remove allocations later
78
+ return shuffles_class_indices, Array (class_indicators' )
78
79
end
79
80
80
81
function make_FASTA_DNA_w_splits (fp:: String ;
81
82
class_selector= read_ryan_fasta,
82
83
split_ratio= 0.85 ,
83
84
folds= 5 ,
85
+ flux= true ,
84
86
float_type= Float32)
85
87
all_labels, all_dna_read = get_ryan_fasta_str_labels (fp);
86
88
shuffles_class_indices, class_indicators = class_selector (all_labels);
@@ -93,7 +95,7 @@ function make_FASTA_DNA_w_splits(fp::String;
93
95
);
94
96
data_matrix, data_matrix_bg, _, acgt_freq, markov_bg_mat = FastaLoader. get_data_matrices (all_dna_read;
95
97
FloatType= float_type);
96
- return FASTA_DNA_w_splits (mcs,
98
+ fws = FASTA_DNA_w_splits (mcs,
97
99
all_labels,
98
100
class_indicators,
99
101
cu (class_indicators),
@@ -104,6 +106,7 @@ function make_FASTA_DNA_w_splits(fp::String;
104
106
cu (data_matrix),
105
107
data_matrix_bg
106
108
);
109
+ flux && fasta_reshape_for_flux! (fws);
107
110
end
108
111
109
112
function get_test_set_ind (mcs:: multiple_class_splits )
@@ -148,20 +151,20 @@ end
148
151
function get_test_set_for_flux (fws:: FASTA_DNA_w_splits ; gpu= true )
149
152
test_set_ind = get_test_set_ind (fws. mcs);
150
153
if gpu
151
- return fws. data_matrix_gpu[:,:,test_set_ind], fws. label_indicators_gpu[test_set_ind,: ]
154
+ return fws. data_matrix_gpu[:,:,test_set_ind], fws. label_indicators_gpu[:,test_set_ind ]
152
155
else
153
- return fws. data_matrix[:,:,test_set_ind], fws. label_indicators[test_set_ind,: ]
156
+ return fws. data_matrix[:,:,test_set_ind], fws. label_indicators[:,test_set_ind ]
154
157
end
155
158
end
156
159
157
160
function get_train_fold_for_flux (fws:: FASTA_DNA_w_splits , fold:: Int ; gpu= true )
158
161
train_set_ind, valid_set_ind = get_train_fold_ind (fws. mcs, fold)
159
162
if gpu
160
- return fws. data_matrix_gpu[:,:,train_set_ind], fws. label_indicators_gpu[train_set_ind,: ]
161
- fws. data_matrix_gpu[:,:,valid_set_ind], fws. label_indicators_gpu[valid_set_ind,: ]
163
+ return fws. data_matrix_gpu[:,:,train_set_ind], fws. label_indicators_gpu[:,train_set_ind ]
164
+ fws. data_matrix_gpu[:,:,valid_set_ind], fws. label_indicators_gpu[:,valid_set_ind ]
162
165
else
163
- return fws. data_matrix[:,:,train_set_ind], fws. label_indicators[train_set_ind,: ]
164
- fws. data_matrix[:,:,valid_set_ind], fws. label_indicators[valid_set_ind,: ]
166
+ return fws. data_matrix[:,:,train_set_ind], fws. label_indicators[:,train_set_ind ]
167
+ fws. data_matrix[:,:,valid_set_ind], fws. label_indicators[:,valid_set_ind ]
165
168
end
166
169
end
167
170
0 commit comments