Commit ff5a6e3

Merge pull request #106 from antoinedemathelin/master
feat: RegularTransfer Gaussian Process
2 parents: 324930c + 62384af

9 files changed: +380 −7 lines

adapt/instance_based/_kliep.py (+1 −1)

@@ -221,7 +221,7 @@ def __init__(self,
                  max_centers=100,
                  cv=5,
                  algo="FW",
-                 lr=np.logspace(-3,1,5),
+                 lr=[0.001, 0.01, 0.1, 1.0, 10.0],
                  tol=1e-6,
                  max_iter=2000,
                  copy=True,

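The new `lr` default is the same learning-rate grid the old `np.logspace` call produced, only written out as a plain list. A quick standalone NumPy check (not part of the library) confirms the two defaults are numerically identical:

```python
import numpy as np

# np.logspace(-3, 1, 5) = 10 ** linspace(-3, 1, 5) = 10**-3, 10**-2, ..., 10**1
old_default = np.logspace(-3, 1, 5)
new_default = [0.001, 0.01, 0.1, 1.0, 10.0]

print(old_default)                            # [1.e-03 1.e-02 1.e-01 1.e+00 1.e+01]
print(np.allclose(old_default, new_default))  # True
```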
adapt/parameter_based/__init__.py (+2 −1)

@@ -2,7 +2,7 @@
 Parameter-Based Methods Module
 """
 
-from ._regular import RegularTransferLR, RegularTransferLC, RegularTransferNN
+from ._regular import RegularTransferLR, RegularTransferLC, RegularTransferNN, RegularTransferGP
 from ._finetuning import FineTuning
 from ._transfer_tree import TransferTreeClassifier
 from ._transfer_tree import TransferForestClassifier

@@ -13,6 +13,7 @@
 __all__ = ["RegularTransferLR",
            "RegularTransferLC",
            "RegularTransferNN",
+           "RegularTransferGP",
            "FineTuning",
            "TransferTreeClassifier",
            "TransferForestClassifier",

adapt/parameter_based/_regular.py (+183 −3)

@@ -5,6 +5,8 @@
 import numpy as np
 from sklearn.preprocessing import LabelBinarizer
 from scipy.sparse.linalg import lsqr
+from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier
+from sklearn.linear_model import LinearRegression
 import tensorflow as tf
 from tensorflow.keras import Sequential
 from tensorflow.keras.layers import Flatten, Dense

@@ -186,7 +188,8 @@ def fit(self, Xt=None, yt=None, **fit_params):
 
         if yt_ndim_below_one_:
             self.coef_ = self.coef_.reshape(-1)
-            self.intercept_ = self.intercept_[0]
+            if self.estimator_.fit_intercept:
+                self.intercept_ = self.intercept_[0]
 
         self.estimator_.coef_ = self.coef_
         if self.estimator_.fit_intercept:

@@ -267,7 +270,11 @@ def fit(self, Xt=None, yt=None, **fit_params):
         Xt, yt = check_arrays(Xt, yt)
 
         _label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
-        yt = _label_binarizer.fit_transform(yt)
+        _label_binarizer.fit(self.estimator.classes_)
+        yt = _label_binarizer.transform(yt)
+
+        print(yt.shape)
+
         return super().fit(Xt, yt, **fit_params)

@@ -467,4 +474,177 @@ def predict_disc(self, X):
         """
         Not used.
         """
-        pass
+        pass
+
+
+@make_insert_doc(supervised=True)
+class RegularTransferGP(BaseAdaptEstimator):
+    """
+    Regular Transfer with Gaussian Process
+
+    RegularTransferGP is a parameter-based domain adaptation method.
+
+    The method is based on the assumption that a good target estimator
+    can be obtained by adapting the parameters of a pre-trained source
+    estimator using a few labeled target data.
+
+    The approach consists in fitting the `alpha` coefficients of a
+    Gaussian Process estimator on target data according to an
+    objective function regularized by the Euclidean distance between
+    the source and target `alpha`:
+
+    .. math::
+
+        \\alpha_T = \\underset{\\alpha \\in \\mathbb{R}^n}{\\text{argmin}}
+        \\, ||K_{TS} \\alpha - y_T||^2 + \\lambda ||\\alpha - \\alpha_S||^2
+
+    Where:
+
+    - :math:`\\alpha_T` are the target model coefficients.
+    - :math:`\\alpha_S = \\underset{\\alpha \\in \\mathbb{R}^n}{\\text{argmin}}
+      \\, ||K_{SS} \\alpha - y_S||^2` are the source model coefficients.
+    - :math:`y_S, y_T` are respectively the source and the target labels.
+    - :math:`K_{SS}` is the pairwise kernel distance matrix between source
+      input data.
+    - :math:`K_{TS}` is the pairwise kernel distance matrix between target
+      and source input data.
+    - :math:`n` is the number of source data in :math:`X_S`.
+    - :math:`\\lambda` is a trade-off parameter. The larger :math:`\\lambda`,
+      the closer the target model will be to the source model.
+
+    The ``estimator`` given to ``RegularTransferGP`` should be an instance of
+    ``sklearn.gaussian_process.GaussianProcessRegressor`` or
+    ``sklearn.gaussian_process.GaussianProcessClassifier``.
+
+    Parameters
+    ----------
+    lambda_ : float (default=1.0)
+        Trade-off parameter. For large ``lambda_``, the
+        target model will be similar to the source model.
+
+    Attributes
+    ----------
+    estimator_ : Same class as estimator
+        Fitted Estimator.
+
+    Examples
+    --------
+    >>> from sklearn.gaussian_process import GaussianProcessRegressor
+    >>> from sklearn.gaussian_process.kernels import Matern, WhiteKernel
+    >>> from adapt.utils import make_regression_da
+    >>> from adapt.parameter_based import RegularTransferGP
+    >>> Xs, ys, Xt, yt = make_regression_da()
+    >>> kernel = Matern() + WhiteKernel()
+    >>> src_model = GaussianProcessRegressor(kernel)
+    >>> src_model.fit(Xs, ys)
+    >>> print(src_model.score(Xt, yt))
+    -2.3409379221035382
+    >>> tgt_model = RegularTransferGP(src_model, lambda_=1.)
+    >>> tgt_model.fit(Xt[:3], yt[:3])
+    >>> tgt_model.score(Xt, yt)
+    -0.21947435769240653
+
+    See also
+    --------
+    RegularTransferLR, RegularTransferNN
+
+    References
+    ----------
+    .. [1] `[1] <https://www.microsoft.com/en-us/research/wp-\
+content/uploads/2004/07/2004-chelba-emnlp.pdf>`_ C. Chelba and \
+A. Acero. "Adaptation of maximum entropy classifier: Little data \
+can help a lot". In EMNLP, 2004.
+    """
+
+    def __init__(self,
+                 estimator=None,
+                 Xt=None,
+                 yt=None,
+                 lambda_=1.,
+                 copy=True,
+                 verbose=1,
+                 random_state=None,
+                 **params):
+
+        if not hasattr(estimator, "kernel_"):
+            raise ValueError("`estimator` argument has no ``kernel_`` attribute, "
+                             "please call `fit` on `estimator` or use "
+                             "another estimator such as `GaussianProcessRegressor` or "
+                             "`GaussianProcessClassifier`.")
+
+        estimator = check_fitted_estimator(estimator)
+
+        names = self._get_param_names()
+        kwargs = {k: v for k, v in locals().items() if k in names}
+        kwargs.update(params)
+        super().__init__(**kwargs)
+
+
+    def fit(self, Xt=None, yt=None, **fit_params):
+        """
+        Fit RegularTransferGP.
+
+        Parameters
+        ----------
+        Xt : numpy array (default=None)
+            Target input data.
+
+        yt : numpy array (default=None)
+            Target output data.
+
+        fit_params : key, value arguments
+            Not used. Here for sklearn compatibility.
+
+        Returns
+        -------
+        self : returns an instance of self
+        """
+        Xt, yt = self._get_target_data(Xt, yt)
+        Xt, yt = check_arrays(Xt, yt)
+        set_random_seed(self.random_state)
+
+        self.estimator_ = check_estimator(self.estimator,
+                                          copy=self.copy,
+                                          force_copy=True)
+
+        if isinstance(self.estimator, GaussianProcessRegressor):
+            src_linear_model = LinearRegression(fit_intercept=False)
+            src_linear_model.coef_ = self.estimator_.alpha_.transpose()
+
+            Kt = self.estimator_.kernel_(Xt, self.estimator_.X_train_)
+            tgt_linear_model = RegularTransferLR(src_linear_model, lambda_=self.lambda_)
+
+            tgt_linear_model.fit(Kt, yt)
+
+            self.estimator_.alpha_ = np.copy(tgt_linear_model.coef_).transpose()
+
+        elif isinstance(self.estimator, GaussianProcessClassifier):
+
+            if hasattr(self.estimator_.base_estimator_, "estimators_"):
+                for i in range(len(self.estimator_.base_estimator_.estimators_)):
+                    c = self.estimator_.classes_[i]
+                    if sum(yt == c) > 0:
+                        yt_c = np.zeros(yt.shape[0])
+                        yt_c[yt == c] = 1
+                        self.estimator_.base_estimator_.estimators_[i] = self._fit_one_vs_one_classifier(
+                            self.estimator_.base_estimator_.estimators_[i], Xt, yt_c)
+            else:
+                self.estimator_.base_estimator_ = self._fit_one_vs_one_classifier(
+                    self.estimator_.base_estimator_, Xt, yt)
+        return self
+
+
+    def _fit_one_vs_one_classifier(self, estimator, Xt, yt):
+        src_linear_model = LinearRegression(fit_intercept=False)
+        src_linear_model.coef_ = (estimator.y_train_ - estimator.pi_)
+        src_linear_model.classes_ = estimator.classes_
+        Kt = estimator.kernel_(Xt, estimator.X_train_)
+
+        tgt_linear_model = RegularTransferLC(src_linear_model, lambda_=self.lambda_)
+
+        tgt_linear_model.fit(Kt, yt)
+
+        estimator.pi_ = (estimator.y_train_ - np.copy(tgt_linear_model.coef_).ravel())
+        return estimator

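The regularized objective in the new docstring has a simple closed form. The sketch below works it out with plain NumPy on toy data; it only illustrates the math, not the code path `RegularTransferGP.fit` actually uses (which, as the diff shows, reuses `RegularTransferLR` on the kernel matrix):

```python
import numpy as np
from sklearn.gaussian_process.kernels import Matern, WhiteKernel

rng = np.random.RandomState(0)

# Toy stand-ins for the quantities named in the docstring.
Xs, ys = rng.randn(30, 2), rng.randn(30)   # source data (n = 30)
Xt, yt = rng.randn(5, 2), rng.randn(5)     # a few labeled target points
lambda_ = 1.0
kernel = Matern() + WhiteKernel()

# Source coefficients: least-squares solution of K_SS @ alpha ~ y_S.
K_ss = kernel(Xs, Xs)
alpha_s = np.linalg.lstsq(K_ss, ys, rcond=None)[0]

# Target coefficients: argmin ||K_TS @ a - y_T||^2 + lambda * ||a - alpha_S||^2,
# whose closed form is (K_TS.T K_TS + lambda I)^-1 (K_TS.T y_T + lambda alpha_S).
K_ts = kernel(Xt, Xs)
A = K_ts.T @ K_ts + lambda_ * np.eye(len(Xs))
b = K_ts.T @ yt + lambda_ * alpha_s
alpha_t = np.linalg.solve(A, b)

print(alpha_t.shape)  # (30,): one coefficient per source training point
```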
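For the `GaussianProcessClassifier` branch, a usage sketch analogous to the regressor doctest above might look like the following. It is illustrative only, not taken from the commit's tests: `make_classification_da` is assumed to be the classification counterpart of `make_regression_da` in `adapt.utils`, and the chosen `lambda_` and sample counts are arbitrary.

```python
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import Matern
from adapt.utils import make_classification_da   # assumed counterpart of make_regression_da
from adapt.parameter_based import RegularTransferGP

# Source-only classifier, fitted before adaptation (RegularTransferGP requires
# a fitted estimator exposing `kernel_`).
Xs, ys, Xt, yt = make_classification_da()
src_model = GaussianProcessClassifier(Matern())
src_model.fit(Xs, ys)

# Adapt with a handful of labeled target samples. For classifiers, the latent
# coefficients (y_train_ - pi_) of each binary sub-estimator are re-fitted with
# an l2 pull toward their source values; lambda_ controls the strength of the pull.
tgt_model = RegularTransferGP(src_model, lambda_=10.)
tgt_model.fit(Xt[:10], yt[:10])
print(tgt_model.score(Xt, yt))
```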
src_docs/_templates/layout.html (+3)

@@ -66,9 +66,12 @@
 <li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.RegularTransferLR") }}">RegularTransferLR</a></li>
 <li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.RegularTransferLC") }}">RegularTransferLC</a></li>
 <li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.RegularTransferNN") }}">RegularTransferNN</a></li>
+<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.RegularTransferGP") }}">RegularTransferGP</a></li>
 <li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.FineTuning") }}">FineTuning</a></li>
 <li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.TransferTreeClassifier") }}">TransferTreeClassifier</a></li>
 <li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.TransferForestClassifier") }}">TransferForestClassifier</a></li>
+<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.TransferTreeSelector") }}">TransferTreeSelector</a></li>
+<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.TransferForestSelector") }}">TransferForestSelector</a></li>
 </ul>
 </li>
 <li class="toctree-l1"><a class="reference internal" href="{{ pathto("contents") }}{{ contents }}{{ "adapt-metrics" }}">Metrics</a><ul>

src_docs/contents.rst (+3)

@@ -146,9 +146,12 @@ the **source** data are adapted to build a suited model for the **task** on the
    parameter_based.RegularTransferLR
    parameter_based.RegularTransferLC
    parameter_based.RegularTransferNN
+   parameter_based.RegularTransferGP
    parameter_based.FineTuning
    parameter_based.TransferTreeClassifier
    parameter_based.TransferForestClassifier
+   parameter_based.TransferTreeSelector
+   parameter_based.TransferForestSelector
 
 
 .. _adapt.metrics:

New file (+47 lines): Sphinx API reference page for RegularTransferGP

@@ -0,0 +1,47 @@
+:ref:`adapt.parameter_based <adapt.parameter_based>`.RegularTransferGP
+===========================================================================
+
+.. currentmodule:: adapt.parameter_based
+
+.. autoclass:: RegularTransferGP
+    :no-members:
+    :no-inherited-members:
+    :no-special-members:
+
+
+
+
+.. rubric:: Methods
+
+.. autosummary::
+
+    ~RegularTransferGP.__init__
+    ~RegularTransferGP.fit
+    ~RegularTransferGP.fit_estimator
+    ~RegularTransferGP.get_params
+    ~RegularTransferGP.predict
+    ~RegularTransferGP.predict_estimator
+    ~RegularTransferGP.score
+    ~RegularTransferGP.set_params
+    ~RegularTransferGP.unsupervised_score
+
+
+.. automethod:: __init__
+.. automethod:: fit
+.. automethod:: fit_estimator
+.. automethod:: get_params
+.. automethod:: predict
+.. automethod:: predict_estimator
+.. automethod:: score
+.. automethod:: set_params
+.. automethod:: unsupervised_score
+
+
+
+
+.. raw:: html
+
+    <h2> Examples </h2>
+
+.. include:: ../gallery/RegularTransferGP.rst
+
New file (+49 lines): Sphinx API reference page for TransferForestSelector

@@ -0,0 +1,49 @@
+:ref:`adapt.parameter_based <adapt.parameter_based>`.TransferForestSelector
+================================================================================
+
+.. currentmodule:: adapt.parameter_based
+
+.. autoclass:: TransferForestSelector
+    :no-members:
+    :no-inherited-members:
+    :no-special-members:
+
+
+
+
+.. rubric:: Methods
+
+.. autosummary::
+
+    ~TransferForestSelector.__init__
+    ~TransferForestSelector.fit
+    ~TransferForestSelector.fit_estimator
+    ~TransferForestSelector.get_params
+    ~TransferForestSelector.model_selection
+    ~TransferForestSelector.predict
+    ~TransferForestSelector.predict_estimator
+    ~TransferForestSelector.score
+    ~TransferForestSelector.set_params
+    ~TransferForestSelector.unsupervised_score
+
+
+.. automethod:: __init__
+.. automethod:: fit
+.. automethod:: fit_estimator
+.. automethod:: get_params
+.. automethod:: model_selection
+.. automethod:: predict
+.. automethod:: predict_estimator
+.. automethod:: score
+.. automethod:: set_params
+.. automethod:: unsupervised_score
+
+
+
+
+.. raw:: html
+
+    <h2> Examples </h2>
+
+.. include:: ../gallery/TransferForestSelector.rst
+
