
Commit 0832e6e

Parameter tests.
1 parent 7ac273a commit 0832e6e

3 files changed: 41 additions & 23 deletions

python-package/xgboost/testing/params.py

Lines changed: 12 additions & 0 deletions
@@ -47,3 +47,15 @@
         "max_cat_threshold": strategies.integers(1, 128),
     }
 )
+
+lambdarank_parameter_strategy = strategies.fixed_dictionaries(
+    {
+        "lambdarank_unbiased": strategies.sampled_from([True, False]),
+        "lambdarank_pair_method": strategies.sampled_from(["topk", "mean"]),
+        "lambdarank_num_pair_per_sample": strategies.integers(1, 8),
+        "lambdarank_bias_norm": strategies.floats(0.5, 2.0),
+        "objective": strategies.sampled_from(
+            ["rank:ndcg", "rank:map", "rank:pairwise"]
+        ),
+    }
+)
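
For context, `strategies.fixed_dictionaries` draws one plain dict per Hypothesis example, each key sampled from its own sub-strategy, so a drawn example can be splatted straight into an estimator constructor. A minimal standalone sketch, not part of this commit, with a reduced stand-in strategy that mirrors only two of the keys added above:

from hypothesis import given, settings, strategies

# Hypothetical stand-in for lambdarank_parameter_strategy, reduced to two keys.
params_strategy = strategies.fixed_dictionaries(
    {
        "lambdarank_pair_method": strategies.sampled_from(["topk", "mean"]),
        "lambdarank_num_pair_per_sample": strategies.integers(1, 8),
    }
)


@given(params_strategy)
@settings(deadline=None)
def test_sampled_params_are_well_formed(params):
    # Each drawn example is a dict such as {"lambdarank_pair_method": "topk", ...}.
    assert params["lambdarank_pair_method"] in ("topk", "mean")
    assert 1 <= params["lambdarank_num_pair_per_sample"] <= 8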

src/metric/rank_metric.cc

Lines changed: 4 additions & 17 deletions
@@ -360,7 +360,10 @@ class EvalRankWithCache : public Metric {
     }
     param_.UpdateAllowUnknown(Args{});
   }
-
+  void Configure(Args const&) override {
+    // do not configure, otherwise the ndcg param will be forced into the same as the one in
+    // objective.
+  }
   void LoadConfig(Json const& in) override {
     if (IsA<Null>(in)) {
       return;
@@ -418,11 +421,6 @@ double Finalize(double score, double sw) {
 class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
  public:
   using EvalRankWithCache::EvalRankWithCache;
-
-  void Configure(Args const&) override {
-    // do not configure, otherwise the ndcg param will be forced into the same as the one in
-    // objective.
-  }
   const char* Name() const override { return name_.c_str(); }

   double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
@@ -484,17 +482,6 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
 class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
  public:
   using EvalRankWithCache::EvalRankWithCache;
-
-  void Configure(Args const&) override {
-    // do not configure, otherwise the ndcg param will be forced into the same as the one in
-    // objective.
-  }
-  void SaveConfig(Json* p_out) const override {
-    auto& out = *p_out;
-    out["name"] = String{this->Name()};
-    out["map_param"] = ToJson(param_);
-  }
-  void LoadConfig(Json const& in) override { FromJson(in["map_param"], &param_); }
   const char* Name() const override { return name_.c_str(); }

   double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
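
The no-op `Configure` override now lives once in `EvalRankWithCache`, so every ranking metric keeps its own parameters (for example the truncation level in `ndcg@k`) rather than having them overwritten by the objective's lambdarank configuration. A rough Python-side sketch of the behaviour this guards, using synthetic data and arbitrary parameter values (not part of the commit):

import numpy as np
import xgboost

# Tiny synthetic ranking problem: 100 rows, 2 features, 2 query groups.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = rng.integers(0, 4, size=100)
qid = np.sort(rng.integers(0, 2, size=100))

# The objective has its own lambdarank settings, while each eval metric keeps an
# independent truncation level; a forced re-configure would collapse them together.
ranker = xgboost.XGBRanker(
    objective="rank:ndcg",
    lambdarank_num_pair_per_sample=4,
    eval_metric=["ndcg@2", "ndcg@8"],
    n_estimators=8,
)
ranker.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid])
print(ranker.evals_result()["validation_0"].keys())  # both ndcg@2 and ndcg@8 reported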

tests/python/test_ranking.py

Lines changed: 25 additions & 6 deletions
@@ -5,9 +5,11 @@
 import numpy as np
 import pytest
 from scipy.sparse import csr_matrix
+from hypothesis import given, note, settings

 import xgboost
 from xgboost import testing as tm
+from xgboost.testing.params import lambdarank_parameter_strategy


 def test_ranking_with_unweighted_data():
@@ -74,6 +76,29 @@ def test_ranking_with_weighted_data():
     assert all(p <= q for p, q in zip(is_sorted, is_sorted[1:]))


+def test_error_msg() -> None:
+    X, y, qid, w = tm.make_ltr(10, 2, 2, 2)
+    ranker = xgboost.XGBRanker()
+    with pytest.raises(ValueError, match=r"equal to the number of query groups"):
+        ranker.fit(X, y, qid=qid, sample_weight=y)
+
+
+@given(lambdarank_parameter_strategy)
+@settings(deadline=None, print_blob=True)
+def test_lambdarank_parameters(params):
+    if params["objective"] == "rank:map":
+        rel = 2
+    else:
+        rel = 5
+    X, y, q, w = tm.make_ltr(4096, 3, 13, rel)
+    ranker = xgboost.XGBRanker(tree_method="hist", n_estimators=64, **params)
+    ranker.fit(X, y, qid=q, sample_weight=w, eval_set=[(X, y)], eval_qid=[q])
+    for k, v in ranker.evals_result()["validation_0"].items():
+        note(v)
+        assert v[-1] > v[0]
+    assert ranker.n_features_in_ == 3
+
+
 class TestRanking:
     @classmethod
     def setup_class(cls):
@@ -119,12 +144,6 @@ def teardown_class(cls):
         if os.path.exists(directory):
             shutil.rmtree(directory)

-    def test_error_msg(self) -> None:
-        X, y, qid, w = tm.make_ltr(10, 2, 2, 2)
-        ranker = xgboost.XGBRanker()
-        with pytest.raises(ValueError, match=r"equal to the number of query groups"):
-            ranker.fit(X, y, qid=qid, sample_weight=y)
-
     def test_training(self):
         """
         Train an XGBoost ranking model
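
To run just the new property-based test locally, pytest's Python entry point can be used (a sketch; the node id assumes the repository layout shown above and that Hypothesis is installed):

import pytest

# -q keeps output short; the node id targets only the Hypothesis-driven test.
pytest.main(["-q", "tests/python/test_ranking.py::test_lambdarank_parameters"])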
