|
1 | 1 | import os |
2 | 2 | import tempfile |
3 | 3 | from collections import namedtuple |
4 | | -from typing import Tuple, Union |
| 4 | +from typing import Tuple, Union, Optional |
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | import pytest |
@@ -384,3 +384,25 @@ def test_attribute_error(self, breast_cancer: BreastCancer) -> None: |
384 | 384 |
|
385 | 385 | with pytest.raises(AttributeError, match="early stopping is used"): |
386 | 386 | booster.best_score |
| 387 | + |
| 388 | + def test_preserve_order(self) -> None: |
| 389 | + """Test the ordering of the callbacks is preserved.""" |
| 390 | + X, y, w = tm.make_regression(256, 16, False) |
| 391 | + fst_call: Optional[int] = None |
| 392 | + |
| 393 | + # If we use Python `set`, Cb1 is ordered before Cb2. This test makes sure Cb2 is |
| 394 | + # called before Cb1. |
| 395 | + class Cb2(xgb.callback.TrainingCallback): |
| 396 | + def before_iteration(self, model, epoch: int, evals_log) -> bool: |
| 397 | + nonlocal fst_call |
| 398 | + assert fst_call is None or fst_call == 2 |
| 399 | + fst_call = 2 |
| 400 | + return False |
| 401 | + |
| 402 | + class Cb1(xgb.callback.TrainingCallback): |
| 403 | + def before_iteration(self, model, epoch: int, evals_log) -> bool: |
| 404 | + assert fst_call == 2 |
| 405 | + return False |
| 406 | + |
| 407 | + callbacks = [Cb2(), Cb1()] |
| 408 | + xgb.train({}, dtrain=xgb.QuantileDMatrix(X, y, weight=w), callbacks=callbacks) |
0 commit comments