Skip to content

Commit

Permalink
typ: Annotate Kernel.diag() argument X
Browse files Browse the repository at this point in the history
  • Loading branch information
effigies committed Jan 22, 2025
1 parent c766f7b commit bd38772
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/nifreeze/model/gpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import Callable, ClassVar, Literal, Mapping, Optional, Sequence, Union

import numpy as np
import numpy.typing as npt
from scipy import optimize
from scipy.optimize import Bounds
from sklearn.gaussian_process import GaussianProcessRegressor
Expand Down Expand Up @@ -334,7 +335,7 @@ def __call__(

return self.beta_l * C_theta, K_gradient

def diag(self, X) -> np.ndarray:
def diag(self, X: npt.ArrayLike) -> np.ndarray:
"""Returns the diagonal of the kernel k(X, X).
The result of this method is identical to np.diag(self(X)); however,
Expand All @@ -351,7 +352,7 @@ def diag(self, X) -> np.ndarray:
K_diag : :obj:`~numpy.ndarray` of shape (n_samples_X,)
Diagonal of kernel k(X, X)
"""
return self.beta_l * np.ones(X.shape[0])
return self.beta_l * np.ones(np.asanyarray(X).shape[0])

def is_stationary(self) -> bool:
"""Returns whether the kernel is stationary."""
Expand Down Expand Up @@ -444,7 +445,7 @@ def __call__(

return self.beta_l * C_theta, K_gradient

def diag(self, X) -> np.ndarray:
def diag(self, X: npt.ArrayLike) -> np.ndarray:
"""Returns the diagonal of the kernel k(X, X).
The result of this method is identical to np.diag(self(X)); however,
Expand All @@ -461,7 +462,7 @@ def diag(self, X) -> np.ndarray:
K_diag : :obj:`~numpy.ndarray` of shape (n_samples_X,)
Diagonal of kernel k(X, X)
"""
return self.beta_l * np.ones(X.shape[0])
return self.beta_l * np.ones(np.asanyarray(X).shape[0])

def is_stationary(self) -> bool:
"""Returns whether the kernel is stationary."""
Expand Down

0 comments on commit bd38772

Please sign in to comment.