@@ -49,7 +49,7 @@ def cross_validate(
49
49
cv : int ,
50
50
n_repeats : int ,
51
51
gpr : DiffusionGPR ,
52
- ) -> dict [ int , list [ tuple [ np .ndarray , np . ndarray , np . ndarray , np . ndarray ]]] :
52
+ ) -> np .ndarray :
53
53
"""
54
54
Perform the experiment by estimating the dMRI signal using a Gaussian process model.
55
55
@@ -211,10 +211,10 @@ def main() -> None:
211
211
212
212
if args .kfold :
213
213
# Use Scikit-learn cross validation
214
- scores = defaultdict (list , {})
214
+ scores : dict [ str , list ] = defaultdict (list , {})
215
215
for n in args .kfold :
216
216
for i in range (args .repeats ):
217
- cv_scores = - 1.0 * cross_validate (X , y .T , n , gpr )
217
+ cv_scores = - 1.0 * cross_validate (X , y .T , n , i , gpr )
218
218
scores ["rmse" ] += cv_scores .tolist ()
219
219
scores ["repeat" ] += [i ] * len (cv_scores )
220
220
scores ["n_folds" ] += [n ] * len (cv_scores )
@@ -224,7 +224,7 @@ def main() -> None:
224
224
print (f"Finished { n } -fold cross-validation" )
225
225
226
226
scores_df = pd .DataFrame (scores )
227
- scores_df .to_csv (args .output_scores , sep = "\t " , index = None , na_rep = "n/a" )
227
+ scores_df .to_csv (args .output_scores , sep = "\t " , index = False , na_rep = "n/a" )
228
228
229
229
grouped = scores_df .groupby (["n_folds" ])
230
230
print (grouped [["rmse" ]].mean ())
0 commit comments