Skip to content

Commit a9c77d3

Browse files
committed
Fix ordering of Python callbacks.
1 parent e166942 commit a9c77d3

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

python-package/xgboost/callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def __init__(
160160
output_margin: bool = True,
161161
is_cv: bool = False,
162162
) -> None:
163-
self.callbacks = set(callbacks)
163+
self.callbacks = list(dict.fromkeys(callbacks))
164164
for cb in callbacks:
165165
if not isinstance(cb, TrainingCallback):
166166
raise TypeError("callback must be an instance of `TrainingCallback`.")

tests/python/test_callback.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import tempfile
33
from collections import namedtuple
4-
from typing import Tuple, Union
4+
from typing import Tuple, Union, Optional
55

66
import numpy as np
77
import pytest
@@ -384,3 +384,25 @@ def test_attribute_error(self, breast_cancer: BreastCancer) -> None:
384384

385385
with pytest.raises(AttributeError, match="early stopping is used"):
386386
booster.best_score
387+
388+
def test_preserve_order(self) -> None:
389+
"""Test the ordering of the callbacks is preserved."""
390+
X, y, w = tm.make_regression(256, 16, False)
391+
fst_call: Optional[int] = None
392+
393+
# If we use Python `set`, Cb1 is ordered before Cb2. This test makes sure Cb2 is
394+
# called before Cb1.
395+
class Cb2(xgb.callback.TrainingCallback):
396+
def before_iteration(self, model, epoch: int, evals_log) -> bool:
397+
nonlocal fst_call
398+
assert fst_call is None or fst_call == 2
399+
fst_call = 2
400+
return False
401+
402+
class Cb1(xgb.callback.TrainingCallback):
403+
def before_iteration(self, model, epoch: int, evals_log) -> bool:
404+
assert fst_call == 2
405+
return False
406+
407+
callbacks = [Cb2(), Cb1()]
408+
xgb.train({}, dtrain=xgb.QuantileDMatrix(X, y, weight=w), callbacks=callbacks)

0 commit comments

Comments
 (0)