
Early stopping bugs: 1. training data is not ignored; 2. first_metric_only actually checks the first metric of the first validation data #2410

Closed
Jingyu-Fan opened this issue Sep 14, 2019 · 2 comments


@Jingyu-Fan

Environment info

Operating System: Linux

CPU/GPU model: CPU

C++/Python/R version: Python

LightGBM version or commit hash: 2.2.4

There might be bugs in the early_stopping callback here:

def _callback(env):
    if not cmp_op:
        _init(env)
    if not enabled[0]:
        return
    for i in range_(len(env.evaluation_result_list)):
        score = env.evaluation_result_list[i][2]
        if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
            best_score[i] = score
            best_iter[i] = env.iteration
            best_score_list[i] = env.evaluation_result_list
        elif env.iteration - best_iter[i] >= stopping_rounds:
            if verbose:
                print('Early stopping, best iteration is:\n[%d]\t%s' % (
                    best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
            raise EarlyStopException(best_iter[i], best_score_list[i])
        if env.iteration == env.end_iteration - 1:
            if verbose:
                print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
                    best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
            raise EarlyStopException(best_iter[i], best_score_list[i])
        if first_metric_only:  # the only first metric is used for early stopping
            break

env.evaluation_result_list contains the evaluations for all metrics and all validation data (including the training data, if it was added to valid_sets). So early stopping does not behave the way the documentation says:

But the training data is ignored anyway. To check only the first metric, set the first_metric_only parameter to True in params.
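
For comparison, here is a minimal sketch, not the actual patch from #2209, of how the loop could ignore the training data and restrict early stopping to the first metric on each validation set. It assumes env.model._train_data_name holds the name LightGBM assigns to the training Dataset when it appears in valid_sets ('training' by default); the verbose printing and the final-iteration check are omitted for brevity.

def _callback(env):
    if not cmp_op:
        _init(env)
    if not enabled[0]:
        return
    first_metric = None  # name of the first metric seen on non-training data
    for i in range_(len(env.evaluation_result_list)):
        data_name, metric_name, score = env.evaluation_result_list[i][:3]
        if data_name == env.model._train_data_name:
            continue  # training data is ignored for early stopping
        if first_metric is None:
            first_metric = metric_name
        if first_metric_only and metric_name != first_metric:
            continue  # only the first metric participates in early stopping
        if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
            best_score[i] = score
            best_iter[i] = env.iteration
            best_score_list[i] = env.evaluation_result_list
        elif env.iteration - best_iter[i] >= stopping_rounds:
            raise EarlyStopException(best_iter[i], best_score_list[i])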

The test code only uses one validation set, so the bug was not caught (a sketch of a test case that would expose it follows the snippet below):

def test_early_stopping_for_only_first_metric(self):
    X, y = load_boston(True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
    params = {
        'objective': 'regression',
        'metric': 'None',
        'verbose': -1
    }
    lgb_train = lgb.Dataset(X_train, y_train)
    lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
    decreasing_generator = itertools.count(0, -1)

    def decreasing_metric(preds, train_data):
        return ('decreasing_metric', next(decreasing_generator), False)

    def constant_metric(preds, train_data):
        return ('constant_metric', 0.0, False)

    # test that all metrics are checked (default behaviour)
    gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval],
                    feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
                                                     constant_metric(preds, train_data)],
                    early_stopping_rounds=5, verbose_eval=False)
    self.assertEqual(gbm.best_iteration, 1)

    # test that only the first metric is checked
    gbm = lgb.train(dict(params, first_metric_only=True), lgb_train,
                    num_boost_round=20, valid_sets=[lgb_eval],
                    feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
                                                     constant_metric(preds, train_data)],
                    early_stopping_rounds=5, verbose_eval=False)
    self.assertEqual(gbm.best_iteration, 20)

    # ... change the order of metrics
    gbm = lgb.train(dict(params, first_metric_only=True), lgb_train,
                    num_boost_round=20, valid_sets=[lgb_eval],
                    feval=lambda preds, train_data: [constant_metric(preds, train_data),
                                                     decreasing_metric(preds, train_data)],
                    early_stopping_rounds=5, verbose_eval=False)
    self.assertEqual(gbm.best_iteration, 1)
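
If the test also included the training data as a validation set, the bug would surface. Below is a hedged sketch of such a test case; it assumes feval receives the Dataset currently being evaluated (so a custom metric can behave differently on training and validation data), and train_only_decreasing_metric is a hypothetical helper introduced only for this illustration.

    def train_only_decreasing_metric(preds, eval_data):
        # improves every round on the training set, stays constant elsewhere
        # (assumes eval_data is the Dataset being evaluated; num_data() tells them apart)
        if eval_data.num_data() == lgb_train.num_data():
            return ('train_sensitive_metric', next(decreasing_generator), False)
        return ('train_sensitive_metric', 0.0, False)

    gbm = lgb.train(dict(params, first_metric_only=True), lgb_train,
                    num_boost_round=20, valid_sets=[lgb_train, lgb_eval],
                    feval=train_only_decreasing_metric,
                    early_stopping_rounds=5, verbose_eval=False)
    # documented behaviour: training entries are ignored and the metric on the
    # validation set is constant, so training should stop early (best_iteration == 1);
    # the buggy loop tracks the training entry instead and runs all 20 rounds
    self.assertEqual(gbm.best_iteration, 1)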

Here is code to demonstrate the issue. Only one line was added to the early_stopping callback, to print env.evaluation_result_list at each iteration.

import warnings
from operator import gt, lt
import numpy as np
import lightgbm as lgb
from lightgbm.callback import EarlyStopException
from lightgbm.callback import _format_eval_result
from lightgbm.compat import range_

def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
    best_score = []
    best_score_avg = []
    best_iter = []
    best_iter_avg = []
    best_score_list = []
    best_score_avg_list = []
    cmp_op = []
    enabled = [True]

    def _init(env):
        enabled[0] = not any((boost_alias in env.params
                              and env.params[boost_alias] == 'dart') for boost_alias in ('boosting',
                                                                                         'boosting_type',
                                                                                         'boost'))
        if not enabled[0]:
            warnings.warn('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')
        if verbose:
            msg = "Training until validation scores don't improve for {} rounds."
            print(msg.format(stopping_rounds))
        for eval_ret in env.evaluation_result_list:
            best_iter.append(0)
            best_score_list.append(None)
            if eval_ret[3]:
                best_score.append(float('-inf'))
                # best_score_avg = float('-inf')
                cmp_op.append(gt)
            else:
                best_score.append(float('inf'))
                # best_score_avg = float('inf')
                cmp_op.append(lt)
        best_score_avg.append(None)
        best_iter_avg.append(None)
        best_score_avg_list.append(None)

    def _callback(env):
        if not cmp_op:
            _init(env)
        if not enabled[0]:
            return
        # added line to print env.evaluation_result_list
        print(env.evaluation_result_list)
        for i in range_(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    print('Early stopping, best iteration is:\n[%d]\t%s' % (
                        best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                raise EarlyStopException(best_iter[i], best_score_list[i])
            if env.iteration == env.end_iteration - 1:
                if verbose:
                    print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
                        best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                raise EarlyStopException(best_iter[i], best_score_list[i])
            if first_metric_only:  # the only first metric is used for early stopping
                break

    _callback.order = 30
    return _callback



data = np.random.random((500, 10))
y = [1] * 250 + [0] * 250
lgb_train = lgb.Dataset(data, y, free_raw_data=True)


data = np.random.random((500, 10))
y = [1] * 250 + [0] * 250
lgb_test = lgb.Dataset(data, y, free_raw_data=True)


params = {
    'objective': 'binary',
    'verbose': 1,
    'metric': ['binary_logloss', 'auc']
}

gbm = lgb.train(params=params,
                train_set=lgb_train,
                valid_sets=[lgb_train, lgb_test],
                num_boost_round=6,
                callbacks=[early_stopping(1, first_metric_only=False, verbose=True)]
                )

And here is the printed output:

[1]	training's auc: 0.741264	training's binary_logloss: 0.675381	valid_1's auc: 0.475968	valid_1's binary_logloss: 0.695726
Training until validation scores don't improve for 1 rounds.
[('training', u'auc', '0.741264', True), ('training', u'binary_logloss', '0.6753807446588581', False), ('valid_1', u'auc', '0.475968', True), ('valid_1', u'binary_logloss', '0.6957255240695474', False)]
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[2]	training's auc: 0.804496	training's binary_logloss: 0.659374	valid_1's auc: 0.471824	valid_1's binary_logloss: 0.698373
[('training', u'auc', '0.804496', True), ('training', u'binary_logloss', '0.6593740400597452', False), ('valid_1', u'auc', '0.471824', True), ('valid_1', u'binary_logloss', '0.6983732401328614', False)]
Early stopping, best iteration is:
[1]	training's auc: 0.741264	training's binary_logloss: 0.675381	valid_1's auc: 0.475968	valid_1's binary_logloss: 0.695726
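
Until the fix lands, a possible workaround for the first problem (training data not being ignored) is simply to keep the training Dataset out of valid_sets, so env.evaluation_result_list only contains validation entries. A sketch based on the demo above; it does not address the first_metric_only ordering issue.

# Workaround sketch: evaluate only the held-out data, so the callback never
# sees training entries (training metrics are then not reported either).
gbm = lgb.train(params=params,
                train_set=lgb_train,
                valid_sets=[lgb_test],  # training Dataset deliberately excluded
                num_boost_round=6,
                callbacks=[early_stopping(1, first_metric_only=False, verbose=True)])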
@guolinke
Collaborator

Yeah, it is a work in progress. Refer to #2209.

@StrikerRUS
Collaborator

  1. There is already an issue about this: Unexpected Behavior for Early Stopping with Custom Metric #2371. Please follow the discussion there.
  2. This has been fixed in [python] Bug fix for first_metric_only on earlystopping. #2209.
