Skip to content

Commit 55c4e5c

Browse files
authored
[backport] Fix ordering of Python callbacks. (#11812) (#11818)
1 parent 3bfe966 commit 55c4e5c

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

python-package/xgboost/callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def __init__(
161161
output_margin: bool = True,
162162
is_cv: bool = False,
163163
) -> None:
164-
self.callbacks = set(callbacks)
164+
self.callbacks = list(dict.fromkeys(callbacks))
165165
for cb in callbacks:
166166
if not isinstance(cb, TrainingCallback):
167167
raise TypeError("callback must be an instance of `TrainingCallback`.")

tests/python/test_callback.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import tempfile
33
from collections import namedtuple
4-
from typing import Tuple, Union
4+
from typing import Optional, Tuple, Union
55

66
import numpy as np
77
import pytest
@@ -384,3 +384,25 @@ def test_attribute_error(self, breast_cancer: BreastCancer) -> None:
384384

385385
with pytest.raises(AttributeError, match="early stopping is used"):
386386
booster.best_score
387+
388+
def test_preserve_order(self) -> None:
389+
"""Test the ordering of the callbacks is preserved."""
390+
X, y, w = tm.make_regression(256, 16, False)
391+
fst_call: Optional[int] = None
392+
393+
# If we use Python `set`, Cb1 is ordered before Cb2. This test makes sure Cb2 is
394+
# called before Cb1.
395+
class Cb2(xgb.callback.TrainingCallback):
396+
def before_iteration(self, model, epoch: int, evals_log) -> bool:
397+
nonlocal fst_call
398+
assert fst_call is None or fst_call == 2
399+
fst_call = 2
400+
return False
401+
402+
class Cb1(xgb.callback.TrainingCallback):
403+
def before_iteration(self, model, epoch: int, evals_log) -> bool:
404+
assert fst_call == 2
405+
return False
406+
407+
callbacks = [Cb2(), Cb1()]
408+
xgb.train({}, dtrain=xgb.QuantileDMatrix(X, y, weight=w), callbacks=callbacks)

0 commit comments

Comments
 (0)