Skip to content

Commit 1fcc26a

Browse files
authored
Set ndcg to default for LTR. (#8822)
- Add document. - Add tests. - Use `ndcg` with `topk` as default.
1 parent e4dd605 commit 1fcc26a

File tree

18 files changed

+842
-19
lines changed

18 files changed

+842
-19
lines changed

demo/guide-python/learning_to_rank.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""
2+
Getting started with learning to rank
3+
=====================================
4+
5+
.. versionadded:: 2.0.0
6+
7+
This is a demonstration of using XGBoost for learning to rank tasks using the
8+
MSLR_10k_letor dataset. For more infomation about the dataset, please visit its
9+
`description page <https://www.microsoft.com/en-us/research/project/mslr/>`_.
10+
11+
This is a two-part demo, the first one contains a basic example of using XGBoost to
12+
train on relevance degree, and the second part simulates click data and enable the
13+
position debiasing training.
14+
15+
For an overview of learning to rank in XGBoost, please see
16+
:doc:`Learning to Rank </tutorials/learning_to_rank>`.
17+
"""
18+
from __future__ import annotations
19+
20+
import argparse
21+
import json
22+
import os
23+
import pickle as pkl
24+
25+
import numpy as np
26+
import pandas as pd
27+
from sklearn.datasets import load_svmlight_file
28+
29+
import xgboost as xgb
30+
from xgboost.testing.data import RelDataCV, simulate_clicks, sort_ltr_samples
31+
32+
33+
def load_mlsr_10k(data_path: str, cache_path: str) -> RelDataCV:
34+
"""Load the MSLR10k dataset from data_path and cache a pickle object in cache_path.
35+
36+
Returns
37+
-------
38+
39+
A list of tuples [(X, y, qid), ...].
40+
41+
"""
42+
root_path = os.path.expanduser(args.data)
43+
cacheroot_path = os.path.expanduser(args.cache)
44+
cache_path = os.path.join(cacheroot_path, "MSLR_10K_LETOR.pkl")
45+
46+
# Use only the Fold1 for demo:
47+
# Train, Valid, Test
48+
# {S1,S2,S3}, S4, S5
49+
fold = 1
50+
51+
if not os.path.exists(cache_path):
52+
fold_path = os.path.join(root_path, f"Fold{fold}")
53+
train_path = os.path.join(fold_path, "train.txt")
54+
valid_path = os.path.join(fold_path, "vali.txt")
55+
test_path = os.path.join(fold_path, "test.txt")
56+
X_train, y_train, qid_train = load_svmlight_file(
57+
train_path, query_id=True, dtype=np.float32
58+
)
59+
y_train = y_train.astype(np.int32)
60+
qid_train = qid_train.astype(np.int32)
61+
62+
X_valid, y_valid, qid_valid = load_svmlight_file(
63+
valid_path, query_id=True, dtype=np.float32
64+
)
65+
y_valid = y_valid.astype(np.int32)
66+
qid_valid = qid_valid.astype(np.int32)
67+
68+
X_test, y_test, qid_test = load_svmlight_file(
69+
test_path, query_id=True, dtype=np.float32
70+
)
71+
y_test = y_test.astype(np.int32)
72+
qid_test = qid_test.astype(np.int32)
73+
74+
data = RelDataCV(
75+
train=(X_train, y_train, qid_train),
76+
test=(X_test, y_test, qid_test),
77+
max_rel=4,
78+
)
79+
80+
with open(cache_path, "wb") as fd:
81+
pkl.dump(data, fd)
82+
83+
with open(cache_path, "rb") as fd:
84+
data = pkl.load(fd)
85+
86+
return data
87+
88+
89+
def ranking_demo(args: argparse.Namespace) -> None:
90+
"""Demonstration for learning to rank with relevance degree."""
91+
data = load_mlsr_10k(args.data, args.cache)
92+
93+
# Sort data according to query index
94+
X_train, y_train, qid_train = data.train
95+
sorted_idx = np.argsort(qid_train)
96+
X_train = X_train[sorted_idx]
97+
y_train = y_train[sorted_idx]
98+
qid_train = qid_train[sorted_idx]
99+
100+
X_test, y_test, qid_test = data.test
101+
sorted_idx = np.argsort(qid_test)
102+
X_test = X_test[sorted_idx]
103+
y_test = y_test[sorted_idx]
104+
qid_test = qid_test[sorted_idx]
105+
106+
ranker = xgb.XGBRanker(
107+
tree_method="gpu_hist",
108+
lambdarank_pair_method="topk",
109+
lambdarank_num_pair_per_sample=13,
110+
eval_metric=["ndcg@1", "ndcg@8"],
111+
)
112+
ranker.fit(
113+
X_train,
114+
y_train,
115+
qid=qid_train,
116+
eval_set=[(X_test, y_test)],
117+
eval_qid=[qid_test],
118+
verbose=True,
119+
)
120+
121+
122+
def click_data_demo(args: argparse.Namespace) -> None:
123+
"""Demonstration for learning to rank with click data."""
124+
data = load_mlsr_10k(args.data, args.cache)
125+
train, test = simulate_clicks(data)
126+
assert test is not None
127+
128+
assert train.X.shape[0] == train.click.size
129+
assert test.X.shape[0] == test.click.size
130+
assert test.score.dtype == np.float32
131+
assert test.click.dtype == np.int32
132+
133+
X_train, clicks_train, y_train, qid_train = sort_ltr_samples(
134+
train.X,
135+
train.y,
136+
train.qid,
137+
train.click,
138+
train.pos,
139+
)
140+
X_test, clicks_test, y_test, qid_test = sort_ltr_samples(
141+
test.X,
142+
test.y,
143+
test.qid,
144+
test.click,
145+
test.pos,
146+
)
147+
148+
class ShowPosition(xgb.callback.TrainingCallback):
149+
def after_iteration(
150+
self,
151+
model: xgb.Booster,
152+
epoch: int,
153+
evals_log: xgb.callback.TrainingCallback.EvalsLog,
154+
) -> bool:
155+
config = json.loads(model.save_config())
156+
ti_plus = np.array(config["learner"]["objective"]["ti+"])
157+
tj_minus = np.array(config["learner"]["objective"]["tj-"])
158+
df = pd.DataFrame({"ti+": ti_plus, "tj-": tj_minus})
159+
print(df)
160+
return False
161+
162+
ranker = xgb.XGBRanker(
163+
n_estimators=512,
164+
tree_method="gpu_hist",
165+
learning_rate=0.01,
166+
reg_lambda=1.5,
167+
subsample=0.8,
168+
sampling_method="gradient_based",
169+
# LTR specific parameters
170+
objective="rank:ndcg",
171+
# - Enable bias estimation
172+
lambdarank_unbiased=True,
173+
# - normalization (1 / (norm + 1))
174+
lambdarank_bias_norm=1,
175+
# - Focus on the top 12 documents
176+
lambdarank_num_pair_per_sample=12,
177+
lambdarank_pair_method="topk",
178+
ndcg_exp_gain=True,
179+
eval_metric=["ndcg@1", "ndcg@3", "ndcg@5", "ndcg@10"],
180+
callbacks=[ShowPosition()],
181+
)
182+
ranker.fit(
183+
X_train,
184+
clicks_train,
185+
qid=qid_train,
186+
eval_set=[(X_test, y_test), (X_test, clicks_test)],
187+
eval_qid=[qid_test, qid_test],
188+
verbose=True,
189+
)
190+
ranker.predict(X_test)
191+
192+
193+
if __name__ == "__main__":
194+
parser = argparse.ArgumentParser(
195+
description="Demonstration of learning to rank using XGBoost."
196+
)
197+
parser.add_argument(
198+
"--data",
199+
type=str,
200+
help="Root directory of the MSLR-WEB10K data.",
201+
required=True,
202+
)
203+
parser.add_argument(
204+
"--cache",
205+
type=str,
206+
help="Directory for caching processed data.",
207+
required=True,
208+
)
209+
args = parser.parse_args()
210+
211+
ranking_demo(args)
212+
click_data_demo(args)

doc/contrib/coding_guide.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ C++ Coding Guideline
1616
* Each line of text may contain up to 100 characters.
1717
* The use of C++ exceptions is allowed.
1818

19-
- Use C++11 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``.
19+
- Use C++17 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``.
2020
- Use Doxygen to document all the interface code.
21+
- We have some comments around symbols imported by headers, some of those are hinted by `include-what-you-use <https://include-what-you-use.org>`_. It's not required.
22+
- We use clang-tidy and clang-format. You can check their configuration in the root directory of the XGBoost source tree.
2123
- We have a series of automatic checks to ensure that all of our codebase complies with the Google style. Before submitting your pull request, you are encouraged to run the style checks on your machine. See :ref:`running_checks_locally`.
2224

2325
***********************

doc/tutorials/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
2121
monotonic
2222
rf
2323
feature_interaction_constraint
24+
learning_to_rank
2425
aft_survival_analysis
2526
c_api_tutorial
2627
input_format

0 commit comments

Comments
 (0)