Skip to content

Commit 065f6a2

Browse files
Merge pull request #54 from antoinedemathelin/master
fix: change prediction in TrAdaBoost
2 parents 8fd2bbd + b496e89 commit 065f6a2

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

adapt/instance_based/_tradaboost.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -304,12 +304,16 @@ def _boost(self, iboost, Xs, ys, Xt, yt,
304304
warm_start=False,
305305
**fit_params)
306306

307-
if hasattr(estimator, "predict_proba"):
308-
ys_pred = estimator.predict_proba(Xs)
309-
yt_pred = estimator.predict_proba(Xt)
310-
elif hasattr(estimator, "_predict_proba_lr"):
311-
ys_pred = estimator._predict_proba_lr(Xs)
312-
yt_pred = estimator._predict_proba_lr(Xt)
307+
if not isinstance(self, TrAdaBoostR2) and isinstance(estimator, BaseEstimator):
308+
if hasattr(estimator, "predict_proba"):
309+
ys_pred = estimator.predict_proba(Xs)
310+
yt_pred = estimator.predict_proba(Xt)
311+
elif hasattr(estimator, "_predict_proba_lr"):
312+
ys_pred = estimator._predict_proba_lr(Xs)
313+
yt_pred = estimator._predict_proba_lr(Xt)
314+
else:
315+
ys_pred = estimator.predict(Xs)
316+
yt_pred = estimator.predict(Xt)
313317
else:
314318
ys_pred = estimator.predict(Xs)
315319
yt_pred = estimator.predict(Xt)

0 commit comments

Comments
 (0)