Skip to content

Commit fb56aca

Browse files
effigiesjhlegarreta
authored andcommitted
type: Fixes
1 parent beb6eef commit fb56aca

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

scripts/dwi_gp_estimation_error_analysis.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def cross_validate(
4949
cv: int,
5050
n_repeats: int,
5151
gpr: DiffusionGPR,
52-
) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]:
52+
) -> np.ndarray:
5353
"""
5454
Perform the experiment by estimating the dMRI signal using a Gaussian process model.
5555
@@ -211,10 +211,10 @@ def main() -> None:
211211

212212
if args.kfold:
213213
# Use Scikit-learn cross validation
214-
scores = defaultdict(list, {})
214+
scores: dict[str, list] = defaultdict(list, {})
215215
for n in args.kfold:
216216
for i in range(args.repeats):
217-
cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
217+
cv_scores = -1.0 * cross_validate(X, y.T, n, i, gpr)
218218
scores["rmse"] += cv_scores.tolist()
219219
scores["repeat"] += [i] * len(cv_scores)
220220
scores["n_folds"] += [n] * len(cv_scores)
@@ -224,7 +224,7 @@ def main() -> None:
224224
print(f"Finished {n}-fold cross-validation")
225225

226226
scores_df = pd.DataFrame(scores)
227-
scores_df.to_csv(args.output_scores, sep="\t", index=None, na_rep="n/a")
227+
scores_df.to_csv(args.output_scores, sep="\t", index=False, na_rep="n/a")
228228

229229
grouped = scores_df.groupby(["n_folds"])
230230
print(grouped[["rmse"]].mean())

src/nifreeze/registration/ants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def generate_command(
213213
movingmask_path: str | Path | list[str] | None = None,
214214
init_affine: str | Path | None = None,
215215
default: str = "b0-to-b0_level0",
216-
**kwargs: dict,
216+
**kwargs,
217217
) -> str:
218218
"""
219219
Generate an ANTs' command line.

0 commit comments

Comments
 (0)