@@ -71,6 +71,14 @@ class CallbackEnv:
     evaluation_result_list: Optional[_ListOfEvalResultTuples]


+def _is_using_cv(env: CallbackEnv) -> bool:
+    """Check if model in callback env is a CVBooster."""
+    # this import is here to avoid a circular import
+    from .engine import CVBooster
+
+    return isinstance(env.model, CVBooster)
+
+
 def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
     """Format metric string."""
     dataset_name, metric_name, metric_value, *_ = value
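Note: the unpacking used here and reused in the hunks below relies on the two tuple shapes already implied by the old code: a 4-element entry from train() and a 5-element entry from cv(). A rough sketch of those shapes, with illustrative names and values that are not part of this diff:

# Illustrative only: assumed shapes of entries in env.evaluation_result_list.
# train() produces 4-element tuples:
train_item = ("valid_0", "l2", 0.25, False)  # (dataset_name, metric_name, metric_value, is_higher_better)
# cv() produces 5-element tuples; metric_value is the mean over folds and the
# last element is the standard deviation across folds:
cv_item = ("valid", "l2", 0.25, False, 0.02)  # (dataset_name, metric_name, mean, is_higher_better, stdv)

# The callbacks unpack both shapes the same way and branch on tuple length:
dataset_name, metric_name, metric_value, *rest = cv_item
is_cv_result = len(cv_item) != 4  # in the cv() case, rest[-1] holds the stdv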
@@ -143,16 +151,13 @@ def _init(self, env: CallbackEnv) -> None:
             )
         self.eval_result.clear()
         for item in env.evaluation_result_list:
-            if len(item) == 4:  # regular train
-                data_name, eval_name = item[:2]
-            else:  # cv
-                data_name, eval_name = item[1].split()
-            self.eval_result.setdefault(data_name, OrderedDict())
+            dataset_name, metric_name, *_ = item
+            self.eval_result.setdefault(dataset_name, OrderedDict())
             if len(item) == 4:
-                self.eval_result[data_name].setdefault(eval_name, [])
+                self.eval_result[dataset_name].setdefault(metric_name, [])
             else:
-                self.eval_result[data_name].setdefault(f"{eval_name}-mean", [])
-                self.eval_result[data_name].setdefault(f"{eval_name}-stdv", [])
+                self.eval_result[dataset_name].setdefault(f"{metric_name}-mean", [])
+                self.eval_result[dataset_name].setdefault(f"{metric_name}-stdv", [])

     def __call__(self, env: CallbackEnv) -> None:
         if env.iteration == env.begin_iteration:
@@ -163,15 +168,16 @@ def __call__(self, env: CallbackEnv) -> None:
                 "Please report it at https://github.com/microsoft/LightGBM/issues"
             )
         for item in env.evaluation_result_list:
+            # for cv(), 'metric_value' is actually a mean of metric values over all CV folds
+            dataset_name, metric_name, metric_value, *_ = item
             if len(item) == 4:
-                data_name, eval_name, result = item[:3]
-                self.eval_result[data_name][eval_name].append(result)
+                # train()
+                self.eval_result[dataset_name][metric_name].append(metric_value)
             else:
-                data_name, eval_name = item[1].split()
-                res_mean = item[2]
-                res_stdv = item[4]  # type: ignore[misc]
-                self.eval_result[data_name][f"{eval_name}-mean"].append(res_mean)
-                self.eval_result[data_name][f"{eval_name}-stdv"].append(res_stdv)
+                # cv()
+                metric_std_dev = item[4]  # type: ignore[misc]
+                self.eval_result[dataset_name][f"{metric_name}-mean"].append(metric_value)
+                self.eval_result[dataset_name][f"{metric_name}-stdv"].append(metric_std_dev)


 def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
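For reference, the dict this callback fills keeps its public shape; a minimal usage sketch of lgb.train() with record_evaluation() (dataset sizes, parameter values, and names below are illustrative, not taken from this diff):

import lightgbm as lgb
import numpy as np

X = np.random.rand(500, 5)
y = np.random.rand(500)
train_data = lgb.Dataset(X[:400], label=y[:400])
valid_data = lgb.Dataset(X[400:], label=y[400:], reference=train_data)

eval_result = {}
lgb.train(
    params={"objective": "regression", "metric": "l2", "verbose": -1},
    train_set=train_data,
    valid_sets=[valid_data],
    valid_names=["valid"],
    num_boost_round=10,
    callbacks=[lgb.record_evaluation(eval_result)],
)
# train(): eval_result["valid"]["l2"] is a list with one value per iteration.
# cv(): per the hunk above, the inner keys become f"{metric}-mean" and f"{metric}-stdv" instead.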
@@ -304,15 +310,15 @@ def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
     def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
         return curr_score < best_score - delta

-    def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
+    def _is_train_set(self, dataset_name: str, env: CallbackEnv) -> bool:
         """Check, by name, if a given Dataset is the training data."""
         # for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
         # and those metrics are considered for early stopping
-        if ds_name == "cv_agg" and eval_name == "train":
+        if _is_using_cv(env) and dataset_name == "train":
             return True

         # for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
-        if isinstance(env.model, Booster) and ds_name == env.model._train_data_name:
+        if isinstance(env.model, Booster) and dataset_name == env.model._train_data_name:
             return True

         return False
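The cv() branch of this check is exercised when lgb.cv() is called with eval_train_metric=True, which adds train-fold metrics to evaluation_result_list that must not drive early stopping. A hedged sketch of such a call (data and parameter values are illustrative):

import lightgbm as lgb
import numpy as np

X = np.random.rand(500, 5)
y = np.random.rand(500)
train_data = lgb.Dataset(X, label=y)

cv_results = lgb.cv(
    params={"objective": "regression", "metric": "l2", "verbose": -1},
    train_set=train_data,
    num_boost_round=50,
    nfold=3,
    eval_train_metric=True,  # adds train-fold metrics to evaluation_result_list
    callbacks=[lgb.early_stopping(stopping_rounds=5)],
)
# With this change, cv() training-set entries are identified by _is_using_cv(env) plus
# dataset_name == "train", replacing the old ("cv_agg", "train <metric>") string match.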
@@ -327,11 +333,13 @@ def _init(self, env: CallbackEnv) -> None:
             _log_warning("Early stopping is not available in dart mode")
             return

+        # get details of the first dataset
+        first_dataset_name, first_metric_name, *_ = env.evaluation_result_list[0]
+
         # validation sets are guaranteed to not be identical to the training data in cv()
         if isinstance(env.model, Booster):
             only_train_set = len(env.evaluation_result_list) == 1 and self._is_train_set(
-                ds_name=env.evaluation_result_list[0][0],
-                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
+                dataset_name=first_dataset_name,
                 env=env,
             )
             if only_train_set:
@@ -370,8 +378,7 @@ def _init(self, env: CallbackEnv) -> None:
                 _log_info(f"Using {self.min_delta} as min_delta for all metrics.")
             deltas = [self.min_delta] * n_datasets * n_metrics

-        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
-        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
+        self.first_metric = first_metric_name
         for eval_ret, delta in zip(env.evaluation_result_list, deltas):
             self.best_iter.append(0)
             if eval_ret[3]:  # greater is better
@@ -381,15 +388,15 @@ def _init(self, env: CallbackEnv) -> None:
                 self.best_score.append(float("inf"))
                 self.cmp_op.append(partial(self._lt_delta, delta=delta))

-    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
+    def _final_iteration_check(self, *, env: CallbackEnv, metric_name: str, i: int) -> None:
         if env.iteration == env.end_iteration - 1:
             if self.verbose:
                 best_score_str = "\t".join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                 _log_info(
                     "Did not meet early stopping. " f"Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}"
                 )
             if self.first_metric_only:
-                _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
+                _log_info(f"Evaluated only: {metric_name}")
             raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

     def __call__(self, env: CallbackEnv) -> None:
@@ -405,21 +412,18 @@ def __call__(self, env: CallbackEnv) -> None:
         # self.best_score_list is initialized to an empty list
         first_time_updating_best_score_list = self.best_score_list == []
         for i in range(len(env.evaluation_result_list)):
-            score = env.evaluation_result_list[i][2]
-            if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
-                self.best_score[i] = score
+            dataset_name, metric_name, metric_value, *_ = env.evaluation_result_list[i]
+            if first_time_updating_best_score_list or self.cmp_op[i](metric_value, self.best_score[i]):
+                self.best_score[i] = metric_value
                 self.best_iter[i] = env.iteration
                 if first_time_updating_best_score_list:
                     self.best_score_list.append(env.evaluation_result_list)
                 else:
                     self.best_score_list[i] = env.evaluation_result_list
-            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
-            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
-            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
+            if self.first_metric_only and self.first_metric != metric_name:
                 continue  # use only the first metric for early stopping
             if self._is_train_set(
-                ds_name=env.evaluation_result_list[i][0],
-                eval_name=eval_name_splitted[0],
+                dataset_name=dataset_name,
                 env=env,
             ):
                 continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
@@ -430,9 +434,9 @@ def __call__(self, env: CallbackEnv) -> None:
                     )
                     _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                     if self.first_metric_only:
-                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
+                        _log_info(f"Evaluated only: {metric_name}")
                 raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
-            self._final_iteration_check(env, eval_name_splitted, i)
+            self._final_iteration_check(env=env, metric_name=metric_name, i=i)


 def _should_enable_early_stopping(stopping_rounds: Any) -> bool:
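The public way to construct the callback whose internals change above is unchanged; a brief sketch (argument values are illustrative):

import lightgbm as lgb

# first_metric_only now compares against the metric name taken directly from
# evaluation_result_list[0]; min_delta feeds the _gt_delta/_lt_delta comparisons above.
stopper = lgb.early_stopping(stopping_rounds=10, first_metric_only=True, min_delta=1e-4, verbose=True)
# pass it to lgb.train() or lgb.cv() via callbacks=[stopper]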