Skip to content

Commit 664a26b

Browse files
committed
test.
1 parent 4aee105 commit 664a26b

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

python-package/xgboost/testing/multi_target.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for multi-target training."""
22

3+
import json
34
from typing import Dict, Optional, Tuple
45

56
import numpy as np
@@ -248,3 +249,29 @@ def run_with_iter(device: Device) -> None: # pylint: disable=too-many-locals
248249
evals_result_0["Train"]["rmse"], evals_result_2["Train"]["rmse"]
249250
)
250251
assert_allclose(device, booster_0.inplace_predict(X), booster_2.inplace_predict(X))
252+
253+
254+
def run_eta(device: Device) -> None:
    """Check that predictions scale linearly with the learning rate.

    With a single boosting round and ``base_score`` fixed at 0, the model
    output is just ``eta`` times the leaf values, so predictions from
    boosters trained with different learning rates must be exact scalar
    multiples of each other.

    Parameters
    ----------
    device :
        The device ("cpu" or "cuda") to run the training on.
    """
    from sklearn.datasets import make_regression

    X, y = make_regression(512, 16, random_state=2025, n_targets=3)
    params = {
        "device": device,
        "multi_strategy": "multi_output_tree",
        "learning_rate": 1.0,
        "debug_synchronize": True,
        "base_score": 0.0,
    }
    Xy = QuantileDMatrix(X, y)

    # Train one booster per learning rate and collect the predictions.
    predictions = []
    for eta in (1.0, 0.1, 2.0):
        params["learning_rate"] = eta
        booster = train(params, Xy, num_boost_round=1)
        predictions.append(booster.predict(Xy))
    predt_0, predt_1, predt_2 = predictions

    # eta=1.0 output must be 10x the eta=0.1 output, and half the eta=2.0 one.
    np.testing.assert_allclose(predt_0, predt_1 * 10, rtol=1e-6)
    np.testing.assert_allclose(predt_0 * 2, predt_2, rtol=1e-6)

tests/python-gpu/test_gpu_multi_target.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
run_multilabel,
55
run_reduced_grad,
66
run_with_iter,
7+
run_eta,
78
)
89

910

@@ -24,3 +25,7 @@ def test_reduced_grad() -> None:
2425
def test_with_iter() -> None:
    """Run multi-target training from a data iterator on CUDA with RMM enabled."""
    with config_context(use_rmm=True):
        run_with_iter("cuda")
28+
29+
30+
def test_eta() -> None:
    """Verify learning-rate scaling of multi-target predictions on CUDA."""
    run_eta("cuda")

0 commit comments

Comments
 (0)