
Support early stopping on 1.2-2 #201


Open · wants to merge 1 commit into base: 1.2-2
9 changes: 6 additions & 3 deletions src/sagemaker_xgboost_container/algorithm_mode/train.py
@@ -195,8 +195,10 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_di
     # Evaluation metrics to use with train() API
     tuning_objective_metric_param = train_cfg.pop("_tuning_objective_metric", None)
     eval_metric = train_cfg.get("eval_metric")
-    cleaned_eval_metric, configured_feval = train_utils.get_eval_metrics_and_feval(
+
+    cleaned_eval_metric, configured_feval, maximize_feval_metric = train_utils.get_eval_metrics_and_feval(
         tuning_objective_metric_param, eval_metric)
+
     if cleaned_eval_metric:
         train_cfg['eval_metric'] = cleaned_eval_metric
     else:
@@ -217,7 +219,8 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_di

         bst = xgb.train(train_cfg, train_dmatrix, num_boost_round=num_round-iteration, evals=watchlist,
                         feval=configured_feval, early_stopping_rounds=early_stopping_rounds,
-                        callbacks=callbacks, xgb_model=xgb_model, verbose_eval=False)
+                        maximize=maximize_feval_metric, callbacks=callbacks, xgb_model=xgb_model,
+                        verbose_eval=False)

     else:
         num_cv_round = train_cfg.pop("_num_cv_round", 1)
@@ -249,7 +252,7 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_di
            logging.info("Train cross validation fold {}".format((len(bst) % kfold) + 1))
            booster = xgb.train(train_cfg, cv_train_dmatrix, num_boost_round=num_round-iteration,
                                evals=watchlist, feval=configured_feval, evals_result=evals_result,
-                               early_stopping_rounds=early_stopping_rounds,
+                               early_stopping_rounds=early_stopping_rounds, maximize=maximize_feval_metric,
                                callbacks=callbacks, xgb_model=xgb_model, verbose_eval=False)
            bst.append(booster)
            evals_results.append(evals_result)
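For reference, XGBoost's built-in early stopping watches the last metric of the last entry in `evals`, and when that metric comes from a custom `feval` the `maximize` flag tells the booster which direction counts as an improvement — which is what this change threads through to both `xgb.train` calls. Below is a minimal, self-contained sketch of the pattern using synthetic data and invented names (not container code):

```python
import numpy as np
import xgboost as xgb

# Synthetic binary-classification data, purely for illustration.
rng = np.random.RandomState(0)
dtrain = xgb.DMatrix(rng.rand(200, 5), label=rng.randint(0, 2, 200))
dval = xgb.DMatrix(rng.rand(50, 5), label=rng.randint(0, 2, 50))


def accuracy_feval(preds, dmatrix):
    """Custom eval metric: higher is better, so early stopping must maximize it."""
    labels = dmatrix.get_label()
    return "accuracy", float(np.mean((preds > 0.5) == labels))


bst = xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=100,
    evals=[(dtrain, "train"), (dval, "validation")],
    feval=accuracy_feval,            # custom metric, evaluated last
    early_stopping_rounds=5,         # stop when validation accuracy stops improving
    maximize=True,                   # default is to minimize, which would stop in the wrong direction
    verbose_eval=False,
)
print(bst.best_iteration, bst.best_score)
```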
22 changes: 16 additions & 6 deletions src/sagemaker_xgboost_container/algorithm_mode/train_utils.py
@@ -13,6 +13,7 @@
 import logging
 import os
 from sagemaker_xgboost_container.metrics.custom_metrics import get_custom_metrics, configure_feval
+from sagemaker_xgboost_container.constants.xgb_constants import XGB_MAXIMIZE_METRICS


 HPO_SEPARATOR = ':'
@@ -21,10 +22,12 @@
 # These are helper functions for parsing the list of metrics to be outputted
 def get_union_metrics(metric_a, metric_b):
     """Union of metric_a and metric_b
+    We make sure the tuning objective metrics are at the end of the list, because XGBoost's internal
+    early stopping tracks the last metric in the list (here, the tuning objective metric).

-    :param metric_a: list
-    :param metric_b: list
-    :return: Union metrics list from metric_a and metric_b
+    :param metric_a: list, tuning objective metrics
+    :param metric_b: list, eval metrics defined within xgboost
+    :return: Union metrics list from metric_a and metric_b, with the metrics from metric_a at the end
     """
     if metric_a is None and metric_b is None:
         return None
@@ -33,7 +36,12 @@ def get_union_metrics(metric_a, metric_b):
     elif metric_b is None:
         return metric_a
     else:
-        metric_list = list(set(metric_a).union(metric_b))
+        for metric in metric_a:
+            if metric in metric_b:
+                # remove duplicate metrics
+                metric_b.remove(metric)
+        metric_list = metric_b + metric_a
+        assert metric_list[-1] == metric_a[-1]
         return metric_list
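The ordering contract introduced here is easiest to see in isolation. A small illustrative snippet with invented metric names (not container code), equivalent to the loop above — the previous set-based union left the position of the tuning objective metric unpredictable:

```python
# metric_b: eval_metric values already configured for XGBoost
# metric_a: the tuning objective metric(s) coming from SageMaker HPO
eval_metrics = ["rmse", "mae", "auc"]
tuning_metrics = ["auc"]

# Same effect as the new get_union_metrics loop: drop duplicates from metric_b,
# then append metric_a, so the tuning objective ends up last and drives early stopping.
deduped = [m for m in eval_metrics if m not in tuning_metrics]
union = deduped + tuning_metrics
print(union)  # ['rmse', 'mae', 'auc']
```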


@@ -59,15 +67,17 @@ def get_eval_metrics_and_feval(tuning_objective_metric_param, eval_metric):

     union_metrics = get_union_metrics(tuning_objective_metric, eval_metric)

+    maximize_feval_metric = None
     if union_metrics is not None:
         feval_metrics = get_custom_metrics(union_metrics)
         if feval_metrics:
             configured_eval = configure_feval(feval_metrics)
-            cleaned_eval_metrics = list(set(union_metrics) - set(feval_metrics))
+            cleaned_eval_metrics = [metric for metric in union_metrics if metric not in feval_metrics]
+            maximize_feval_metric = True if feval_metrics[-1] in XGB_MAXIMIZE_METRICS else False
         else:
             cleaned_eval_metrics = union_metrics

-    return cleaned_eval_metrics, configured_eval
+    return cleaned_eval_metrics, configured_eval, maximize_feval_metric


 def cleanup_dir(dir, file_prefix):
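The new `maximize_feval_metric` flag is decided by checking the last container-defined metric against `XGB_MAXIMIZE_METRICS`. A sketch of that decision, assuming illustrative contents for `XGB_MAXIMIZE_METRICS` (the real list lives in `sagemaker_xgboost_container.constants.xgb_constants`):

```python
# Assumed contents for illustration only; the real list is defined in
# sagemaker_xgboost_container/constants/xgb_constants.py.
XGB_MAXIMIZE_METRICS = ["auc", "accuracy", "f1", "map"]

# Container-defined (feval) metrics, with the tuning objective metric last
# thanks to get_union_metrics.
feval_metrics = ["f1", "accuracy"]

maximize_feval_metric = feval_metrics[-1] in XGB_MAXIMIZE_METRICS
print(maximize_feval_metric)  # True -> forwarded as xgb.train(..., maximize=True)
```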
4 changes: 2 additions & 2 deletions src/sagemaker_xgboost_container/metrics/custom_metrics.py
@@ -133,9 +133,9 @@ def r2(preds, dtrain):
 }


-def get_custom_metrics(eval_metrics):
+def get_custom_metrics(union_metrics):
     """Get container defined metrics from metrics list."""
-    return set(eval_metrics).intersection(CUSTOM_METRICS.keys())
+    return [metric for metric in union_metrics if metric in CUSTOM_METRICS.keys()]


 def configure_feval(custom_metric_list):
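The switch from a set intersection to a list comprehension is what keeps the last-metric convention intact: a set discards input order, so `feval_metrics[-1]` would be arbitrary. A quick illustration with invented metric names (not container code):

```python
CUSTOM_METRIC_NAMES = {"accuracy", "f1", "mse", "r2"}   # stand-in for CUSTOM_METRICS.keys()
union_metrics = ["error", "f1", "accuracy"]             # tuning objective ("accuracy") is last

as_set = set(union_metrics).intersection(CUSTOM_METRIC_NAMES)     # order is undefined
as_list = [m for m in union_metrics if m in CUSTOM_METRIC_NAMES]  # ['f1', 'accuracy']
print(as_list[-1])  # 'accuracy' -- the metric the maximize decision and early stopping key on
```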