diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index 3766df5d..fa44e250 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -5,6 +5,16 @@ What's new in the package ========================= + + +Develop branch +---------------- + +- Fix :class:`pyriemann_qiskit.utils.filtering.ChannelSelection` that caused incorrect channel selection + + + + v0.4.0 ------ diff --git a/pyriemann_qiskit/utils/filtering.py b/pyriemann_qiskit/utils/filtering.py index 4fe58f7e..afcb72d3 100644 --- a/pyriemann_qiskit/utils/filtering.py +++ b/pyriemann_qiskit/utils/filtering.py @@ -139,13 +139,6 @@ def __init__(self, n_channels, cov_est="lwf"): self.n_channels = n_channels self.cov_est = cov_est - @staticmethod - def _get_indices(maxes, mean_cov): - indices = [] - for v in maxes: - indices.extend(np.argwhere(mean_cov == v).flatten()) - return np.unique(indices) - def fit(self, X, y=None, **kwargs): """Select channel based on covariances @@ -166,19 +159,10 @@ def fit(self, X, y=None, **kwargs): covs = Covariances(estimator=self.cov_est).fit_transform(X) # Get the average covariance between the channels. mean_cov = np.mean(covs, axis=0) - n_feats, _ = mean_cov.shape # Select the `n_channels` channels having the maximum covariances. - all_max = [] - for i in range(n_feats): - for j in range(i, n_feats): - self._chs_idx = ChannelSelection._get_indices(all_max, mean_cov) - - if len(self._chs_idx) < self.n_channels: - all_max.append(mean_cov[i, j]) - else: - if mean_cov[i, j] > max(all_max): - all_max[np.argmin(all_max)] = mean_cov[i, j] - + self._chs_idx = np.argpartition( + np.max(mean_cov, axis=0), -self.n_channels, axis=None + )[-self.n_channels :] return self def transform(self, X, **kwargs):