Skip to content

Commit 6d85d8b

Browse files
oestebanjhlegarreta
authored andcommitted
Fix MultiShellKernel type issues
1 parent eda3893 commit 6d85d8b

2 files changed

Lines changed: 13 additions & 2 deletions

File tree

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@ module = [
155155
"joblib",
156156
"h5py",
157157
"ConfigSpace",
158+
"scipy.*",
159+
"sklearn.*",
160+
"skimage.*",
161+
"pandas",
162+
"attrs",
158163
]
159164
ignore_missing_imports = true
160165

src/nifreeze/model/gpr.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def _constrained_optimization(
252252
options=options,
253253
args=(self.eval_gradient,),
254254
tol=self.tol,
255-
)
255+
) # type: ignore[call-overload]
256256
return opt_res.x, opt_res.fun
257257

258258
if callable(self.optimizer):
@@ -479,6 +479,9 @@ def __repr__(self) -> str:
479479
class MultiShellKernel(KernelOperator):
480480
"""Composite kernel for multi-shell diffusion data."""
481481

482+
k1: Kernel
483+
k2: Kernel
484+
482485
def __init__(
483486
self,
484487
orientation_kernel: Kernel | None = None,
@@ -511,10 +514,13 @@ def __call__(
511514
eval_gradient: bool = False,
512515
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
513516
X_o, X_b = self._split(X)
517+
Y_o: np.ndarray | None
518+
Y_b: np.ndarray | None
514519
if Y is not None:
515520
Y_o, Y_b = self._split(Y)
516521
else:
517-
Y_o = Y_b = None
522+
Y_o = None
523+
Y_b = None
518524

519525
if eval_gradient:
520526
K1, g1 = self.k1(X_o, Y_o, eval_gradient=True)

0 commit comments

Comments
 (0)