Skip to content

Commit bdd60a7

Browse files
committed
Optionally return platt scaling classifier.
1 parent 48c3dcf commit bdd60a7

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

calibration/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def bootstrap_std(data: List[T], estimator=None, num_samples=100) -> Tuple[float
453453

454454
# Re-Calibration utilities.
455455

456-
def get_platt_scaler(model_probs, labels):
456+
def get_platt_scaler(model_probs, labels, get_clf=False):
457457
clf = LogisticRegression(C=1e10, solver='lbfgs')
458458
eps = 1e-12
459459
model_probs = model_probs.astype(dtype=np.float64)
@@ -468,6 +468,8 @@ def calibrator(probs):
468468
x = x * clf.coef_[0] + clf.intercept_
469469
output = 1 / (1 + np.exp(-x))
470470
return output
471+
if get_clf:
472+
return calibrator, clf
471473
return calibrator
472474

473475

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name="uncertainty-calibration",
8-
version="0.1.2",
8+
version="0.1.3",
99
author="Ananya Kumar",
1010
author_email="[email protected]",
1111
description="Utilities to calibrate model uncertainties and measure calibration.",

0 commit comments

Comments
 (0)