Skip to content

Commit

Permalink
type: Annotate Kernel.diag() argument X
Browse files Browse the repository at this point in the history
Annotate `Kernel.diag()` argument `X`: use `npt.ArrayLike`.

Fixes:
```
src/nifreeze/model/gpr.py:335: error:
 Argument 1 of "diag" is incompatible with supertype "Kernel";
 supertype defines the argument type as
 "Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes]"
  [override]
src/nifreeze/model/gpr.py:335: note: This violates the Liskov substitution principle
src/nifreeze/model/gpr.py:335: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
src/nifreeze/model/gpr.py:445: error:
 Argument 1 of "diag" is incompatible with supertype "Kernel";
 supertype defines the argument type as
 "Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes]"
  [override]
src/nifreeze/model/gpr.py:445: note: This violates the Liskov substitution principle
src/nifreeze/model/gpr.py:335: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
```

raised for example in:
https://github.com/nipreps/nifreeze/actions/runs/12437972140/job/34728973936#step:8:93

Documentation:
https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
  • Loading branch information
effigies authored and jhlegarreta committed Jan 23, 2025
1 parent 26ec410 commit e03f7f2
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/nifreeze/model/gpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import Callable, ClassVar, Literal, Mapping, Optional, Sequence, Union

import numpy as np
import numpy.typing as npt
from scipy import optimize
from scipy.optimize import Bounds
from sklearn.gaussian_process import GaussianProcessRegressor
Expand Down Expand Up @@ -334,7 +335,7 @@ def __call__(

return self.beta_l * C_theta, K_gradient

def diag(self, X: np.ndarray) -> np.ndarray:
def diag(self, X: npt.ArrayLike) -> np.ndarray:
"""Returns the diagonal of the kernel k(X, X).
The result of this method is identical to np.diag(self(X)); however,
Expand All @@ -351,7 +352,7 @@ def diag(self, X: np.ndarray) -> np.ndarray:
K_diag : :obj:`~numpy.ndarray` of shape (n_samples_X,)
Diagonal of kernel k(X, X)
"""
return self.beta_l * np.ones(X.shape[0])
return self.beta_l * np.ones(np.asanyarray(X).shape[0])

def is_stationary(self) -> bool:
"""Returns whether the kernel is stationary."""
Expand Down Expand Up @@ -444,7 +445,7 @@ def __call__(

return self.beta_l * C_theta, K_gradient

def diag(self, X: np.ndarray) -> np.ndarray:
def diag(self, X: npt.ArrayLike) -> np.ndarray:
"""Returns the diagonal of the kernel k(X, X).
The result of this method is identical to np.diag(self(X)); however,
Expand All @@ -461,7 +462,7 @@ def diag(self, X: np.ndarray) -> np.ndarray:
K_diag : :obj:`~numpy.ndarray` of shape (n_samples_X,)
Diagonal of kernel k(X, X)
"""
return self.beta_l * np.ones(X.shape[0])
return self.beta_l * np.ones(np.asanyarray(X).shape[0])

def is_stationary(self) -> bool:
"""Returns whether the kernel is stationary."""
Expand Down

0 comments on commit e03f7f2

Please sign in to comment.