Skip to content

Commit 859ce5e

Browse files
TimotheeMathieu and rth authored
Enhance Huber robust mean estimator (#121)
* add stopping criterion and add test for huber * add c is None iqr heuristic and correct heuristic in robust_weighted_estimator * Apply suggestions from code review Co-authored-by: Roman Yurchak <[email protected]> * change name c and update doc * change forgotten name c * Revert "change forgotten name c" This reverts commit c5a59ac. * Revert "change name c and update doc" This reverts commit 8dd0cf9. * change c_ to c_numeric * add chegelog Co-authored-by: Roman Yurchak <[email protected]>
1 parent 75dbaa4 commit 859ce5e

File tree

4 files changed

+30
-7
lines changed

4 files changed

+30
-7
lines changed

doc/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Changelog
44
Unreleased
55
----------
66

7+
- Add a stopping criterion and parameter tuning heuristic for Huber robust mean
8+
estimator.
79
- Add `CLARA` (Clustering for Large Applications) which extends k-medoids to
810
be more scalable using a sampling approach.
911
[`#83 <https://github.com/scikit-learn-contrib/scikit-learn-extra/pull/83>`_].

sklearn_extra/robust/mean_estimators.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def median_of_means(X, k, random_state=np.random.RandomState(42)):
8888
return median_of_means_blocked(x, blocks)[0]
8989

9090

91-
def huber(X, c=1.35, T=20):
91+
def huber(X, c=None, T=20, tol=1e-3):
9292
"""Compute the Huber estimator of location of X with parameter c
9393
9494
Parameters
@@ -97,14 +97,19 @@ def huber(X, c=1.35, T=20):
9797
X : array like, length = n_sample
9898
sample from which we want an estimator of the mean
9999
100-
c : float >0, default = 1.35
100+
c : float >0, default = None
101101
parameter that control the robustness of the estimator.
102102
c going to zero gives a behavior close to the median.
103103
c going to infinity gives a behavior close to sample mean.
104+
if c is None, the interquartile range (IQR) is used
105+
as heuristic.
104106
105107
T : int, default = 20
106108
Number of iterations of the algorithm.
107109
110+
tol : float, default=1e-3
111+
Tolerance on stopping criterion.
112+
108113
Return
109114
------
110115
@@ -116,23 +121,38 @@ def huber(X, c=1.35, T=20):
116121
# Initialize the algorithm with a robust first-guess : the median.
117122
mu = np.median(x)
118123

124+
if c is None:
125+
c_numeric = iqr(x)
126+
else:
127+
c_numeric = c
128+
119129
def psisx(x, c):
120130
# Huber weight function.
121131
res = np.zeros(len(x))
122-
mask = np.abs(x) <= c
132+
mask = np.abs(x) <= c_numeric
123133
res[mask] = 1
124-
res[~mask] = c / np.abs(x[~mask])
134+
res[~mask] = c_numeric / np.abs(x[~mask])
125135
return res
126136

137+
# Create a list to keep the ten last values of mu
138+
last_mu = mu
139+
127140
# Run the iterative reweighting algorithm to compute M-estimator.
128141
for t in range(T):
129142
# Compute the weights
130-
w = psisx(x - mu, c)
143+
w = psisx(x - mu, c_numeric)
131144

132145
# Infinite coordinates in x gives zero weight, we take them out.
133146
ind_pos = w > 0
134147

135148
# Update the value of the estimate with the new estimate using the
136149
# new weights.
137150
mu = np.sum(np.array(w[ind_pos]) * x[ind_pos]) / np.sum(w[ind_pos])
151+
152+
# Stopping criterion. The error is decreasing at each iteration
153+
if np.abs(mu - last_mu) < tol:
154+
break
155+
else:
156+
last_mu = mu
157+
138158
return mu

sklearn_extra/robust/robust_weighted_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def _get_weights(self, loss_values, random_state):
460460
if self.weighting == "huber":
461461
if self.c is None:
462462
# If no c parameter given, estimate using inter quartile range.
463-
c = iqr(np.abs(loss_values - np.median(loss_values))) / 2
463+
c = iqr(loss_values) / 2
464464
if c == 0:
465465
warnings.warn(
466466
"Too many samples are parfectly predicted "

sklearn_extra/robust/tests/test_mean_estimators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,6 @@ def test_mom():
2727
def test_huber():
2828
X = np.hstack([np.zeros(90), np.ones(10)])
2929
with pytest.warns(None) as record:
30-
huber(X)
30+
mu = huber(X, c=0.5)
3131
assert len(record) == 0
32+
assert np.abs(mu) < 0.1

0 commit comments

Comments
 (0)