Skip to content

Commit 62384af

Browse files
fix kliep + add test rtgp
1 parent 8c26c9a commit 62384af

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

adapt/instance_based/_kliep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def __init__(self,
221221
max_centers=100,
222222
cv=5,
223223
algo="FW",
224-
lr=np.logspace(-3,1,5),
224+
lr=[0.001, 0.01, 0.1, 1.0, 10.0],
225225
tol=1e-6,
226226
max_iter=2000,
227227
copy=True,

tests/test_regular.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,4 +225,17 @@ def test_regulargp_classif():
225225
tgt_model = RegularTransferGP(src_model, lambda_=1.)
226226
tgt_model.fit(Xt[:3], yt[:3])
227227
score2 = tgt_model.score(Xt, yt)
228+
assert score1 < score2
229+
230+
231+
def test_regulargp_multi_classif():
232+
Xs, ys, Xt, yt = make_classification_da()
233+
ys[:5] = 3
234+
kernel = Matern() + WhiteKernel()
235+
src_model = GaussianProcessClassifier(kernel)
236+
src_model.fit(Xs, ys)
237+
score1 = src_model.score(Xt, yt)
238+
tgt_model = RegularTransferGP(src_model, lambda_=1.)
239+
tgt_model.fit(Xt[:3], yt[:3])
240+
score2 = tgt_model.score(Xt, yt)
228241
assert score1 < score2

0 commit comments

Comments
 (0)