Skip to content

Commit 0a866ec

Browse files
authored
MNT New options for higgs-boson benchmark (scikit-learn#16779)
1 parent fa1ea2a commit 0a866ec

File tree

1 file changed

+35
-40
lines changed

1 file changed

+35
-40
lines changed

benchmarks/bench_hist_gradient_boosting_higgsboson.py

+35-40
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
parser.add_argument('--learning-rate', type=float, default=1.)
2626
parser.add_argument('--subsample', type=int, default=None)
2727
parser.add_argument('--max-bins', type=int, default=255)
28+
parser.add_argument('--no-predict', action="store_true", default=False)
29+
parser.add_argument('--cache-loc', type=str, default='/tmp')
2830
args = parser.parse_args()
2931

3032
HERE = os.path.dirname(__file__)
3133
URL = ("https://archive.ics.uci.edu/ml/machine-learning-databases/00280/"
3234
"HIGGS.csv.gz")
33-
m = Memory(location='/tmp', mmap_mode='r')
35+
m = Memory(location=args.cache_loc, mmap_mode='r')
3436

3537
n_leaf_nodes = args.n_leaf_nodes
3638
n_trees = args.n_trees
@@ -56,6 +58,27 @@ def load_data():
5658
return df
5759

5860

61+
def fit(est, data_train, target_train, libname):
    """Fit ``est`` on the training data and print how long it took.

    Parameters
    ----------
    est : estimator
        Object exposing ``fit(X, y)``.
    data_train : array-like
        Training features, forwarded unchanged to ``est.fit``.
    target_train : array-like
        Training targets, forwarded unchanged to ``est.fit``.
    libname : str
        Library name shown in the progress message.
    """
    # Function-scope import keeps this helper self-contained.  perf_counter
    # is monotonic and high resolution, so the reported duration cannot be
    # skewed by wall-clock adjustments the way ``time.time()`` can.
    from time import perf_counter

    print(f"Fitting a {libname} model...")
    tic = perf_counter()
    est.fit(data_train, target_train)
    toc = perf_counter()
    print(f"fitted in {toc - tic:.3f}s")
67+
68+
69+
def predict(est, data_test, target_test):
    """Score ``est`` on the test set and print timing, ROC AUC and accuracy.

    Returns immediately, without touching ``est``, when the ``--no-predict``
    command-line flag was given (``args.no_predict``).

    Parameters
    ----------
    est : estimator
        Fitted object exposing ``predict`` and ``predict_proba``.
    data_test : array-like
        Test features, forwarded unchanged to both prediction methods.
    target_test : array-like
        True test targets used to compute the two metrics.
    """
    if args.no_predict:
        return

    # Function-scope import keeps this helper self-contained.  perf_counter
    # is monotonic and high resolution, so the reported duration cannot be
    # skewed by wall-clock adjustments the way ``time.time()`` can.
    from time import perf_counter

    # The timed section covers both prediction calls and nothing else; the
    # metrics are computed after ``toc`` so they are not counted.
    tic = perf_counter()
    predicted_test = est.predict(data_test)
    predicted_proba_test = est.predict_proba(data_test)
    toc = perf_counter()

    # Column 1 of the probability matrix is used as the ROC AUC score
    # (presumably the positive class -- confirm against the estimator).
    roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1])
    acc = accuracy_score(target_test, predicted_test)
    print(f"predicted in {toc - tic:.3f}s, "
          f"ROC AUC: {roc_auc:.4f}, ACC: {acc:.4f}")
80+
81+
5982
df = load_data()
6083
target = df.values[:, 0]
6184
data = np.ascontiguousarray(df.values[:, 1:])
@@ -68,8 +91,6 @@ def load_data():
6891
n_samples, n_features = data_train.shape
6992
print(f"Training set with {n_samples} records with {n_features} features.")
7093

71-
print("Fitting a sklearn model...")
72-
tic = time()
7394
est = HistGradientBoostingClassifier(loss='binary_crossentropy',
7495
learning_rate=lr,
7596
max_iter=n_trees,
@@ -78,46 +99,20 @@ def load_data():
7899
early_stopping=False,
79100
random_state=0,
80101
verbose=1)
81-
est.fit(data_train, target_train)
82-
toc = time()
83-
predicted_test = est.predict(data_test)
84-
predicted_proba_test = est.predict_proba(data_test)
85-
roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1])
86-
acc = accuracy_score(target_test, predicted_test)
87-
print(f"done in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}, ACC: {acc :.4f}")
102+
fit(est, data_train, target_train, 'sklearn')
103+
predict(est, data_test, target_test)
88104

89105
# Optional comparisons against third-party libraries.  Each converted model
# is bound to its own name: ``est`` must remain the scikit-learn estimator
# so that every conversion starts from it, even when several of the
# --lightgbm / --xgboost / --catboost flags are enabled in the same run.
if args.lightgbm:
    lightgbm_est = get_equivalent_estimator(est, lib='lightgbm')
    fit(lightgbm_est, data_train, target_train, 'lightgbm')
    predict(lightgbm_est, data_test, target_test)

if args.xgboost:
    xgboost_est = get_equivalent_estimator(est, lib='xgboost')
    fit(xgboost_est, data_train, target_train, 'xgboost')
    predict(xgboost_est, data_test, target_test)

if args.catboost:
    catboost_est = get_equivalent_estimator(est, lib='catboost')
    fit(catboost_est, data_train, target_train, 'catboost')
    predict(catboost_est, data_test, target_test)

0 commit comments

Comments
 (0)