Skip to content

Commit a655de5

Browse files
authored
[MRG] MNT Fixes for PCA with n_components='mle' (scikit-learn#16841)
* Fixed off by one in MLE and better handling of small eigenvalues * light update tests * pep8 * Added test + threshold on small log
1 parent 270c673 commit a655de5

File tree

3 files changed

+87
-76
lines changed

3 files changed

+87
-76
lines changed

doc/whats_new/v0.23.rst

+6-4
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,12 @@ Changelog
139139
- |Fix| :class:`decomposition.PCA` with a float `n_components` parameter, will
140140
exclusively choose the components that explain the variance greater than
141141
`n_components`. :pr:`15669` by :user:`Krishna Chaitanya <krishnachaitanya9>`
142-
- |Fix| :func:`decomposition._pca._assess_dimension` now correctly handles small
143-
eigenvalues. :pr: `4441` by :user:`Lisa Schwetlick <lschwetlick>`, and
144-
:user:`Gelavizh Ahmadi <gelavizh1>` and
145-
:user:`Marija Vlajic Wheeler <marijavlajic>`.
142+
143+
- |Fix| :class:`decomposition.PCA` with `n_components='mle'` now correctly
144+
handles small eigenvalues, and does not infer 0 as the correct number of
145+
components. :pr:`4441` by :user:`Lisa Schwetlick <lschwetlick>`, and
146+
:user:`Gelavizh Ahmadi <gelavizh1>` and :user:`Marija Vlajic Wheeler
147+
<marijavlajic>` and :pr:`16841` by `Nicolas Hug`_.
146148

147149
- |Enhancement| :class:`decomposition.NMF` and
148150
:func:`decomposition.non_negative_factorization` now preserves float32 dtype.

sklearn/decomposition/_pca.py

+33-39
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,22 @@
2828
from ..utils.validation import _deprecate_positional_args
2929

3030

31-
def _assess_dimension(spectrum, rank, n_samples, n_features):
32-
"""Compute the likelihood of a rank ``rank`` dataset.
31+
def _assess_dimension(spectrum, rank, n_samples):
32+
"""Compute the log-likelihood of a rank ``rank`` dataset.
3333
3434
The dataset is assumed to be embedded in gaussian noise of shape(n,
3535
dimf) having spectrum ``spectrum``.
3636
3737
Parameters
3838
----------
39-
spectrum : array of shape (n)
39+
spectrum : array of shape (n_features)
4040
Data spectrum.
4141
rank : int
42-
Tested rank value.
42+
Tested rank value. It should be strictly lower than n_features,
43+
otherwise the method isn't specified (division by zero in equation
44+
(31) from the paper).
4345
n_samples : int
4446
Number of samples.
45-
n_features : int
46-
Number of features.
4747
4848
Returns
4949
-------
@@ -55,45 +55,39 @@ def _assess_dimension(spectrum, rank, n_samples, n_features):
5555
This implements the method of `Thomas P. Minka:
5656
Automatic Choice of Dimensionality for PCA. NIPS 2000: 598-604`
5757
"""
58-
if rank > len(spectrum):
59-
raise ValueError("The tested rank cannot exceed the rank of the"
60-
" dataset")
6158

62-
spectrum_threshold = np.finfo(type(spectrum[0])).eps
59+
n_features = spectrum.shape[0]
60+
if not 1 <= rank < n_features:
61+
raise ValueError("the tested rank should be in [1, n_features - 1]")
62+
63+
eps = 1e-15
64+
65+
if spectrum[rank - 1] < eps:
66+
# When the tested rank is associated with a small eigenvalue, there's
67+
# no point in computing the log-likelihood: it's going to be very
68+
# small and won't be the max anyway. Also, it can lead to numerical
69+
# issues below when computing pa, in particular in log((spectrum[i] -
70+
# spectrum[j]) because this will take the log of something very small.
71+
return -np.inf
6372

6473
pu = -rank * log(2.)
65-
for i in range(rank):
66-
pu += (gammaln((n_features - i) / 2.) -
67-
log(np.pi) * (n_features - i) / 2.)
74+
for i in range(1, rank + 1):
75+
pu += (gammaln((n_features - i + 1) / 2.) -
76+
log(np.pi) * (n_features - i + 1) / 2.)
6877

6978
pl = np.sum(np.log(spectrum[:rank]))
7079
pl = -pl * n_samples / 2.
7180

72-
if rank == n_features:
73-
# TODO: this line is never executed because _infer_dimension's
74-
# for loop is off by one
75-
pv = 0
76-
v = 1
77-
else:
78-
v = np.sum(spectrum[rank:]) / (n_features - rank)
79-
if spectrum_threshold > v:
80-
return -np.inf
81-
pv = -np.log(v) * n_samples * (n_features - rank) / 2.
81+
v = max(eps, np.sum(spectrum[rank:]) / (n_features - rank))
82+
pv = -np.log(v) * n_samples * (n_features - rank) / 2.
8283

8384
m = n_features * rank - rank * (rank + 1.) / 2.
84-
pp = log(2. * np.pi) * (m + rank + 1.) / 2.
85+
pp = log(2. * np.pi) * (m + rank) / 2.
8586

8687
pa = 0.
8788
spectrum_ = spectrum.copy()
8889
spectrum_[rank:n_features] = v
8990
for i in range(rank):
90-
if spectrum_[i] < spectrum_threshold:
91-
# TODO: this line is never executed
92-
# (off by one in _infer_dimension)
93-
# this break only happens when rank == n_features and
94-
# spectrum_[i] < spectrum_threshold, otherwise the early return
95-
# above catches this case.
96-
break
9791
for j in range(i + 1, len(spectrum)):
9892
pa += log((spectrum[i] - spectrum[j]) *
9993
(1. / spectrum_[j] - 1. / spectrum_[i])) + log(n_samples)
@@ -103,15 +97,15 @@ def _assess_dimension(spectrum, rank, n_samples, n_features):
10397
return ll
10498

10599

106-
def _infer_dimension(spectrum, n_samples, n_features):
107-
"""Infers the dimension of a dataset of shape (n_samples, n_features)
100+
def _infer_dimension(spectrum, n_samples):
101+
"""Infers the dimension of a dataset with a given spectrum.
108102
109-
The dataset is described by its spectrum `spectrum`.
103+
The returned value will be in [1, n_features - 1].
110104
"""
111-
n_spectrum = len(spectrum)
112-
ll = np.empty(n_spectrum)
113-
for rank in range(n_spectrum):
114-
ll[rank] = _assess_dimension(spectrum, rank, n_samples, n_features)
105+
ll = np.empty_like(spectrum)
106+
ll[0] = -np.inf # we don't want to return n_components = 0
107+
for rank in range(1, spectrum.shape[0]):
108+
ll[rank] = _assess_dimension(spectrum, rank, n_samples)
115109
return ll.argmax()
116110

117111

@@ -472,7 +466,7 @@ def _fit_full(self, X, n_components):
472466
# Postprocess the number of components required
473467
if n_components == 'mle':
474468
n_components = \
475-
_infer_dimension(explained_variance_, n_samples, n_features)
469+
_infer_dimension(explained_variance_, n_samples)
476470
elif 0 < n_components < 1.0:
477471
# number of components for which the cumulated explained
478472
# variance percentage is superior to the desired threshold

sklearn/decomposition/tests/test_pca.py

+48-33
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def test_n_components_mle(svd_solver):
295295
X = rng.randn(n_samples, n_features)
296296
pca = PCA(n_components='mle', svd_solver=svd_solver)
297297
pca.fit(X)
298-
assert pca.n_components_ == 0
298+
assert pca.n_components_ == 1
299299

300300

301301
@pytest.mark.parametrize("svd_solver", ["arpack", "randomized"])
@@ -333,7 +333,7 @@ def test_infer_dim_1():
333333
pca = PCA(n_components=p, svd_solver='full')
334334
pca.fit(X)
335335
spect = pca.explained_variance_
336-
ll = np.array([_assess_dimension(spect, k, n, p) for k in range(p)])
336+
ll = np.array([_assess_dimension(spect, k, n) for k in range(1, p)])
337337
assert ll[1] > ll.max() - .01 * n
338338

339339

@@ -348,7 +348,7 @@ def test_infer_dim_2():
348348
pca = PCA(n_components=p, svd_solver='full')
349349
pca.fit(X)
350350
spect = pca.explained_variance_
351-
assert _infer_dimension(spect, n, p) > 1
351+
assert _infer_dimension(spect, n) > 1
352352

353353

354354
def test_infer_dim_3():
@@ -361,7 +361,7 @@ def test_infer_dim_3():
361361
pca = PCA(n_components=p, svd_solver='full')
362362
pca.fit(X)
363363
spect = pca.explained_variance_
364-
assert _infer_dimension(spect, n, p) > 2
364+
assert _infer_dimension(spect, n) > 2
365365

366366

367367
@pytest.mark.parametrize(
@@ -570,51 +570,43 @@ def test_pca_n_components_mostly_explained_variance_ratio():
570570
assert pca2.n_components_ == X.shape[1]
571571

572572

573-
def test_infer_dim_bad_spec():
574-
# Test a spectrum that drops to near zero for PR #16224
573+
def test_assess_dimension_bad_rank():
574+
# Test error when tested rank not in [1, n_features - 1]
575575
spectrum = np.array([1, 1e-30, 1e-30, 1e-30])
576576
n_samples = 10
577-
n_features = 5
578-
ret = _infer_dimension(spectrum, n_samples, n_features)
579-
assert ret == 0
577+
for rank in (0, 5):
578+
with pytest.raises(ValueError,
579+
match=r"should be in \[1, n_features - 1\]"):
580+
_assess_dimension(spectrum, rank, n_samples)
580581

581582

582-
def test_assess_dimension_error_rank_greater_than_features():
583-
# Test error when tested rank is greater than the number of features
584-
# for PR #16224
583+
def test_small_eigenvalues_mle():
584+
# Test rank associated with tiny eigenvalues are given a log-likelihood of
585+
# -inf. The inferred rank will be 1
585586
spectrum = np.array([1, 1e-30, 1e-30, 1e-30])
586-
n_samples = 10
587-
n_features = 4
588-
rank = 5
589-
with pytest.raises(ValueError, match="The tested rank cannot exceed "
590-
"the rank of the dataset"):
591-
_assess_dimension(spectrum, rank, n_samples, n_features)
592587

588+
assert _assess_dimension(spectrum, rank=1, n_samples=10) > -np.inf
593589

594-
def test_assess_dimension_small_eigenvalues():
595-
# Test tiny eigenvalues appropriately when using 'mle'
596-
# for PR #16224
597-
spectrum = np.array([1, 1e-30, 1e-30, 1e-30])
598-
n_samples = 10
599-
n_features = 5
600-
rank = 3
601-
ret = _assess_dimension(spectrum, rank, n_samples, n_features)
602-
assert ret == -np.inf
590+
for rank in (2, 3):
591+
assert _assess_dimension(spectrum, rank, 10) == -np.inf
592+
593+
assert _infer_dimension(spectrum, 10) == 1
603594

604595

605-
def test_infer_dim_mle():
606-
# Test small eigenvalues when 'mle' with pathological 'X' dataset
607-
# for PR #16224
608-
X, _ = datasets.make_classification(n_informative=1, n_repeated=18,
596+
def test_mle_redundant_data():
597+
# Test 'mle' with pathological X: only one relevant feature should give a
598+
# rank of 1
599+
X, _ = datasets.make_classification(n_features=20,
600+
n_informative=1, n_repeated=18,
609601
n_redundant=1, n_clusters_per_class=1,
610602
random_state=42)
611603
pca = PCA(n_components='mle').fit(X)
612-
assert pca.n_components_ == 0
604+
assert pca.n_components_ == 1
613605

614606

615607
def test_fit_mle_too_few_samples():
616608
# Tests that an error is raised when the number of samples is smaller
617-
# than the number of features during an mle fit for PR #16224
609+
# than the number of features during an mle fit
618610
X, _ = datasets.make_classification(n_samples=20, n_features=21,
619611
random_state=42)
620612

@@ -623,3 +615,26 @@ def test_fit_mle_too_few_samples():
623615
"supported if "
624616
"n_samples >= n_features"):
625617
pca.fit(X)
618+
619+
620+
def test_mle_simple_case():
621+
# non-regression test for issue
622+
# https://github.com/scikit-learn/scikit-learn/issues/16730
623+
n_samples, n_dim = 1000, 10
624+
X = np.random.RandomState(0).randn(n_samples, n_dim)
625+
X[:, -1] = np.mean(X[:, :-1], axis=-1) # true X dim is ndim - 1
626+
pca_skl = PCA('mle', svd_solver='full')
627+
pca_skl.fit(X)
628+
assert pca_skl.n_components_ == n_dim - 1
629+
630+
631+
def test_assess_dimesion_rank_one():
632+
# Make sure assess_dimension works properly on a matrix of rank 1
633+
n_samples, n_features = 9, 6
634+
X = np.ones((n_samples, n_features)) # rank 1 matrix
635+
_, s, _ = np.linalg.svd(X, full_matrices=True)
636+
assert sum(s[1:]) == 0 # except for rank 1, all eigenvalues are 0
637+
638+
assert np.isfinite(_assess_dimension(s, rank=1, n_samples=n_samples))
639+
for rank in range(2, n_features):
640+
assert _assess_dimension(s, rank, n_samples) == -np.inf

0 commit comments

Comments
 (0)