Skip to content

Commit 0883cb4

Browse files
committed
Reimplement lambdamart ndcg.
* Simplify the implementation for both CPU and GPU. Fix JSON IO. Check labels. Put idx into cache. Optimize. File tag. Weights. Trivial tests. Compatibility. Lint. Fix swap. Device weight. tidy. Easier to read R failure. enum. Fix global configuration. Tidy. msvc omp. dask. Remove ndcg specific parameter. Drop label type for smaller PR. Fix rebase. Fixes. Don't mess with includes. Fixes. Format. Use omp util. Restore some old code. Revert. Port changes from the work on quantile loss. python binding. param. Cleanup. conditional parallel. types. Move doc. fix. need metric rewrite. rename ctx. extract. Work on metric. Metric Init estimation. extract tests, compute ties. cleanup. notes. extract optional weights. init. cleanup. old metric format. note. ndcg cache. nested. debug. fix. log2. Begin CUDA work. temp. Extract sort and latest cuda. truncation. dcg. dispatch. try different gain type. start looking into ub. note. consider writing a doc. check exp gain. Reimplement lambdamart ndcg. * Simplify the implementation for both CPU and GPU. Fix JSON IO. Check labels. Put idx into cache. Optimize. File tag. Weights. Trivial tests. Compatibility. Lint. Fix swap. Device weight. tidy. Easier to read R failure. enum. Fix global configuration. Tidy. msvc omp. dask. Remove ndcg specific parameter. Drop label type for smaller PR. Fix rebase. Fixes. Don't mess with includes. Fixes. Format. Use omp util. Restore some old code. Revert. Port changes from the work on quantile loss. python binding. param. Cleanup. conditional parallel. types. Move doc. fix. need metric rewrite. rename ctx. extract. Work on metric. Metric Init estimation. extract tests, compute ties. cleanup. notes. extract optional weights. init. cleanup. old metric format. note. ndcg cache. nested. debug. fix. log2. Begin CUDA work. temp. Extract sort and latest cuda. truncation. dcg. dispatch. try different gain type. start looking into ub. note. consider writing a doc. check exp gain. Start looking into unbiased. lambda. Extract the ndcg cache. header. cleanup namespace. small check. namespace. init with param. gain. extract. groups. Cleanup. disable. debug. remove. Revert "remove." This reverts commit ea025f9. sigmoid. cleanup. metric name. check scores. note. check map. extract utilities. avoid inline. fix. header. extract more. note. note. note. start working on map. fix. continue map. map. matrix. Remove map. note. format. move check. cleanup. use cached discount, use double. cleanup. Add position to the Python interface. pass it into lambda. Full ratio. rank. comment. some work on GPU. compile. move cache initialization. descending. Fix arg sort. basic ndcg score. metric weight. config. extract. pass position again. Define a metric decorator. position. decorate metric.. return. note. irrelevant docs. fix weights. header. Share the bias. Use position check info. use cache for param. note. prepare to work on deterministic gpu. rounding. Extract op. cleanup. Use it. check label. ditch launchn. rounding. Move rounding into cache. fix check label. GPU fixes. Irrelevant doc. try to avoid inf. mad. Work on metric cache. Cleanup sort. use cache. cache others. revert. add test for metric. fixes. msg. note. remove reduce by key. comments. check position. stream. min. small cleanup. use atomic for now. fill. no inline. norm. remove op. start gpu. cleanup. use gpu for update. segmented reduce. revert. comments. comments. fix. comments. fix bounds. comments. cache. pointer. fixes. no spark. revert. Cleanup. cleanup. work on gain type. fix. notes. make metric name. remove. revert. revert. comment. revert. Move back into rank metric. Set name in objective. fix. Don't configure. note. merge tests. accept empty group. fixes. float. revert and fix. not mutable. prototype for cache. extract. convert to DMatrix. cache. Extract the cache. Port changes. fix & cleanup. cleanup. cleanup. Rename. restore. remove. header. revert. rename. rename. doc. cleanup. doc. cleanup. tests. tests. split up. jvm parameters. doc. Fix. Use cache in cox. Revert "Use cache in cox." This reverts commit e1cec37. Remove pairwise. iwyu. rename. Move. Merge. ranking utils. Fixes. rename. Comments. todos. Small cleanup. doc. Start working on demo. move some code here. rename. Update doc. Update doc. Work on demo. work on demo. demo. Demo. Specify the max rel degree. remove position. Fix. Work on demo. demo. Using only one fold. cache. demo. schema. comments. Lint. fix test. automake. macos. schema. test. schema. lint. fix tests. Implement MAP and pair sampling. revert sorting. Work on ranknet. remove. Don't upgrade cost if larger than. Extract GPU make pairs. error message. Remove. Cleanup some gpu tests. Move. Move NDCG test. fix weights. Move rest of the tests. Remove. Work on tests. fixes. Cleanup. header. cleanup. Update document. update document. fix build. cpplint. rename. Fixes and cleanup. Cleanup tests. lint. fix tests. debug macos non-openmp checks. macos. fix ndcg test. Ensure number of threads is smaller than the number of inputs. fix. Debug macos. fixes. Add weight normalization. Note on reproducible result. Don't normalize if it's binary. old ctk. Use old objective. Update doc. Convert pyspark tests. black. Fix rebase. Fix rebase. Start looking into CV. Hacky score function. extract parsing. Cleanup and tests. Lint & note. test check. Update document. Update tests & doc. Support custom metric as well. c++-17. cleanup old metrics. rename. Fixes. Fix cxx test. test cudf. start converting tests. pylint. fix data load. Cleanup the tests. Parameter tests. isort. Fix test. Specify src path for isort. 17 goodies. Fix rebase. Start working on ranking cache tests. Extract CPU impl. test debiasing. use index. ranking cache. comment. some work on debiasing. save the estimated bias. normalize by default. GPU norm. fix gpu unbiased. cleanup. cleanup. Remove workaround. Default to topk. Restore. Cleanup. Revert change in algorithm. norm. Move data generation process in testing for reuse. Move sort samples as well. cleanup. Generate data. lint. pylint. Fix. Fix spark test. avoid sampling with unbiased. Cleanup demo. Handle single group simulation. Numeric issues. More numeric issues. sigma. naming. Simple test. tests. brief description. Revert "brief description." This reverts commit 0b3817a. rebase. symbol. Rebase. disable normalization. Revert "disable normalization." This reverts commit ef3133d2b4a76714f3514808c6e2ae5937e6a8c2.
1 parent a84a1fd commit 0883cb4

File tree

39 files changed

+1424
-1178
lines changed

39 files changed

+1424
-1178
lines changed

R-package/src/Makevars.in

-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ OBJECTS= \
3232
$(PKGROOT)/src/objective/objective.o \
3333
$(PKGROOT)/src/objective/regression_obj.o \
3434
$(PKGROOT)/src/objective/multiclass_obj.o \
35-
$(PKGROOT)/src/objective/rank_obj.o \
3635
$(PKGROOT)/src/objective/lambdarank_obj.o \
3736
$(PKGROOT)/src/objective/hinge.o \
3837
$(PKGROOT)/src/objective/aft_obj.o \

R-package/src/Makevars.win

-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ OBJECTS= \
3232
$(PKGROOT)/src/objective/objective.o \
3333
$(PKGROOT)/src/objective/regression_obj.o \
3434
$(PKGROOT)/src/objective/multiclass_obj.o \
35-
$(PKGROOT)/src/objective/rank_obj.o \
3635
$(PKGROOT)/src/objective/lambdarank_obj.o \
3736
$(PKGROOT)/src/objective/hinge.o \
3837
$(PKGROOT)/src/objective/aft_obj.o \

demo/guide-python/learning_to_rank.py

+210
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
"""
2+
Getting started with learning to rank
3+
=====================================
4+
5+
.. versionadded:: 2.0.0
6+
7+
This is a demonstration of using XGBoost for learning to rank tasks using the
8+
MSLR_10k_letor dataset. For more infomation about the dataset, please visit its
9+
`description page <https://www.microsoft.com/en-us/research/project/mslr/>`_.
10+
11+
This is a two-part demo, the first one contains a basic example of using XGBoost to
12+
train on relevance degree, and the second part simulates click data and enable the
13+
position debiasing training.
14+
15+
For an overview of learning to rank in XGBoost, please see
16+
:doc:`Learning to Rank </tutorials/learning_to_rank>`.
17+
"""
18+
from __future__ import annotations
19+
20+
import argparse
21+
import json
22+
import os
23+
import pickle as pkl
24+
25+
import numpy as np
26+
import pandas as pd
27+
from sklearn.datasets import load_svmlight_file
28+
29+
import xgboost as xgb
30+
from xgboost.testing.data import RelDataCV, simulate_clicks, sort_ltr_samples
31+
32+
33+
def load_mlsr_10k(data_path: str, cache_path: str) -> RelDataCV:
34+
"""Load the MSLR10k dataset from data_path and cache a pickle object in cache_path.
35+
36+
Returns
37+
-------
38+
39+
A list of tuples [(X, y, qid), ...].
40+
41+
"""
42+
root_path = os.path.expanduser(args.data)
43+
cacheroot_path = os.path.expanduser(args.cache)
44+
cache_path = os.path.join(cacheroot_path, "MSLR_10K_LETOR.pkl")
45+
46+
# Use only the Fold1 for demo:
47+
# Train, Valid, Test
48+
# {S1,S2,S3}, S4, S5
49+
fold = 1
50+
51+
if not os.path.exists(cache_path):
52+
fold_path = os.path.join(root_path, f"Fold{fold}")
53+
train_path = os.path.join(fold_path, "train.txt")
54+
valid_path = os.path.join(fold_path, "vali.txt")
55+
test_path = os.path.join(fold_path, "test.txt")
56+
X_train, y_train, qid_train = load_svmlight_file(
57+
train_path, query_id=True, dtype=np.float32
58+
)
59+
y_train = y_train.astype(np.int32)
60+
qid_train = qid_train.astype(np.int32)
61+
62+
X_valid, y_valid, qid_valid = load_svmlight_file(
63+
valid_path, query_id=True, dtype=np.float32
64+
)
65+
y_valid = y_valid.astype(np.int32)
66+
qid_valid = qid_valid.astype(np.int32)
67+
68+
X_test, y_test, qid_test = load_svmlight_file(
69+
test_path, query_id=True, dtype=np.float32
70+
)
71+
y_test = y_test.astype(np.int32)
72+
qid_test = qid_test.astype(np.int32)
73+
74+
data = RelDataCV(
75+
train=(X_train, y_train, qid_train),
76+
test=(X_test, y_test, qid_test),
77+
max_rel=4,
78+
)
79+
80+
with open(cache_path, "wb") as fd:
81+
pkl.dump(data, fd)
82+
83+
with open(cache_path, "rb") as fd:
84+
data = pkl.load(fd)
85+
86+
return data
87+
88+
89+
def ranking_demo(args: argparse.Namespace) -> None:
90+
"""Demonstration for learning to rank with relevance degree."""
91+
data = load_mlsr_10k(args.data, args.cache)
92+
93+
X_train, y_train, qid_train = data.train
94+
sorted_idx = np.argsort(qid_train)
95+
X_train = X_train[sorted_idx]
96+
y_train = y_train[sorted_idx]
97+
qid_train = qid_train[sorted_idx]
98+
99+
X_test, y_test, qid_test = data.test
100+
sorted_idx = np.argsort(qid_test)
101+
X_test = X_test[sorted_idx]
102+
y_test = y_test[sorted_idx]
103+
qid_test = qid_test[sorted_idx]
104+
105+
ranker = xgb.XGBRanker(
106+
tree_method="gpu_hist",
107+
lambdarank_pair_method="topk",
108+
lambdarank_num_pair_per_sample=13,
109+
eval_metric=["ndcg@1", "ndcg@8"],
110+
)
111+
ranker.fit(
112+
X_train,
113+
y_train,
114+
qid=qid_train,
115+
eval_set=[(X_test, y_test)],
116+
eval_qid=[qid_test],
117+
verbose=True,
118+
)
119+
120+
121+
def click_data_demo(args: argparse.Namespace) -> None:
122+
"""Demonstration for learning to rank with click data."""
123+
data = load_mlsr_10k(args.data, args.cache)
124+
folds = simulate_clicks(data)
125+
126+
train = [pack[0] for pack in folds]
127+
test = [pack[1] for pack in folds]
128+
129+
X_train, y_train, qid_train, scores_train, clicks_train, position_train = train
130+
assert X_train.shape[0] == clicks_train.size
131+
X_test, y_test, qid_test, scores_test, clicks_test, position_test = test
132+
assert X_test.shape[0] == clicks_test.size
133+
assert scores_test.dtype == np.float32
134+
assert clicks_test.dtype == np.int32
135+
136+
X_train, clicks_train, y_train, qid_train = sort_ltr_samples(
137+
X_train,
138+
y_train,
139+
qid_train,
140+
clicks_train,
141+
position_train,
142+
)
143+
X_test, clicks_test, y_test, qid_test = sort_ltr_samples(
144+
X_test,
145+
y_test,
146+
qid_test,
147+
clicks_test,
148+
position_test,
149+
)
150+
151+
class ShowPosition(xgb.callback.TrainingCallback):
152+
def after_iteration(self, model, epoch, evals_log) -> bool:
153+
config = json.loads(model.save_config())
154+
ti_plus = np.array(config["learner"]["objective"]["ti+"])
155+
tj_minus = np.array(config["learner"]["objective"]["tj-"])
156+
df = pd.DataFrame({"ti+": ti_plus, "tj-": tj_minus})
157+
print(df)
158+
return False
159+
160+
ranker = xgb.XGBRanker(
161+
n_estimators=512,
162+
tree_method="gpu_hist",
163+
learning_rate=0.01,
164+
reg_lambda=1.5,
165+
subsample=0.8,
166+
sampling_method="gradient_based",
167+
# LTR specific parameters
168+
objective="rank:ndcg",
169+
# - Enable bias estimation
170+
lambdarank_unbiased=True,
171+
# - normalization (1 / (norm + 1))
172+
lambdarank_bias_norm=1,
173+
# - Focus on the top 12 documents
174+
lambdarank_num_pair_per_sample=12,
175+
lambdarank_pair_method="topk",
176+
ndcg_exp_gain=True,
177+
eval_metric=["ndcg@1", "ndcg@3", "ndcg@5", "ndcg@10"],
178+
callbacks=[ShowPosition()],
179+
)
180+
ranker.fit(
181+
X_train,
182+
clicks_train,
183+
qid=qid_train,
184+
eval_set=[(X_test, y_test), (X_test, clicks_test)],
185+
eval_qid=[qid_test, qid_test],
186+
verbose=True,
187+
)
188+
ranker.predict(X_test)
189+
190+
191+
if __name__ == "__main__":
192+
parser = argparse.ArgumentParser(
193+
description="Demonstration of learning to rank using XGBoost."
194+
)
195+
parser.add_argument(
196+
"--data",
197+
type=str,
198+
help="Root directory of the MSLR-WEB10K data.",
199+
required=True,
200+
)
201+
parser.add_argument(
202+
"--cache",
203+
type=str,
204+
help="Directory for caching processed data.",
205+
required=True,
206+
)
207+
args = parser.parse_args()
208+
209+
ranking_demo(args)
210+
click_data_demo(args)

doc/contrib/coding_guide.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ C++ Coding Guideline
1616
* Each line of text may contain up to 100 characters.
1717
* The use of C++ exceptions is allowed.
1818

19-
- Use C++11 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``.
19+
- Use C++14 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``.
2020
- Use Doxygen to document all the interface code.
21+
- We have some comments around symbols imported by headers, some of those are hinted by `include-what-you-use <https://include-what-you-use.org>`_. It's not required.
22+
- We use clang-tidy and clang-format. You can check their configuration in the root directory of the XGBoost source tree.
2123
- We have a series of automatic checks to ensure that all of our codebase complies with the Google style. Before submitting your pull request, you are encouraged to run the style checks on your machine. See :ref:`running_checks_locally`.
2224

2325
***********************

doc/model.schema

+14-4
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,16 @@
219219
"num_pairsample": { "type": "string" },
220220
"fix_list_weight": { "type": "string" }
221221
}
222+
},
223+
"lambdarank_param": {
224+
"type": "object",
225+
"properties": {
226+
"lambdarank_num_pair_per_sample": { "type": "string" },
227+
"lambdarank_pair_method": { "type": "string" },
228+
"lambdarank_unbiased": {"type": "string" },
229+
"lambdarank_bias_norm": {"type": "string" },
230+
"ndcg_exp_gain": {"type": "string"}
231+
}
222232
}
223233
},
224234
"type": "object",
@@ -477,22 +487,22 @@
477487
"type": "object",
478488
"properties": {
479489
"name": { "const": "rank:pairwise" },
480-
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
490+
"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
481491
},
482492
"required": [
483493
"name",
484-
"lambda_rank_param"
494+
"lambdarank_param"
485495
]
486496
},
487497
{
488498
"type": "object",
489499
"properties": {
490500
"name": { "const": "rank:ndcg" },
491-
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
501+
"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
492502
},
493503
"required": [
494504
"name",
495-
"lambda_rank_param"
505+
"lambdarank_param"
496506
]
497507
},
498508
{

doc/parameter.rst

+37-6
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ Parameters for Tree Booster
233233
.. note:: This parameter is working-in-progress.
234234

235235
- The strategy used for training multi-target models, including multi-target regression
236-
and multi-class classification. See :doc:`/tutorials/multioutput` for more information.
236+
and multi-class classification. See :doc:`/tutorials/multioutput` for more information.
237237

238238
- ``one_output_per_tree``: One model for each target.
239239
- ``multi_output_tree``: Use multi-target trees.
@@ -380,9 +380,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
380380
See :doc:`/tutorials/aft_survival_analysis` for details.
381381
- ``multi:softmax``: set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
382382
- ``multi:softprob``: same as softmax, but output a vector of ``ndata * nclass``, which can be further reshaped to ``ndata * nclass`` matrix. The result contains predicted probability of each data point belonging to each class.
383-
- ``rank:pairwise``: Use LambdaMART to perform pairwise ranking where the pairwise loss is minimized
384-
- ``rank:ndcg``: Use LambdaMART to perform list-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) <http://en.wikipedia.org/wiki/NDCG>`_ is maximized
385-
- ``rank:map``: Use LambdaMART to perform list-wise ranking where `Mean Average Precision (MAP) <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_ is maximized
383+
- ``rank:ndcg``: Use LambdaMART to perform pair-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) <http://en.wikipedia.org/wiki/NDCG>`_ is maximized. This objective supports position debiasing for click data.
384+
- ``rank:map``: Use LambdaMART to perform pair-wise ranking where `Mean Average Precision (MAP) <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_ is maximized
385+
- ``rank:pairwise``: Use LambdaRank to perform pair-wise ranking using the `ranknet` objective.
386386
- ``reg:gamma``: gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be `gamma-distributed <https://en.wikipedia.org/wiki/Gamma_distribution#Occurrence_and_applications>`_.
387387
- ``reg:tweedie``: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be `Tweedie-distributed <https://en.wikipedia.org/wiki/Tweedie_distribution#Occurrence_and_applications>`_.
388388

@@ -395,8 +395,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
395395

396396
* ``eval_metric`` [default according to objective]
397397

398-
- Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, mean average precision for ranking)
399-
- User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous one
398+
- Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, `mean average precision` for ``rank:map``, etc.)
399+
- User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous ones
400+
400401
- The choices are listed below:
401402

402403
- ``rmse``: `root mean square error <http://en.wikipedia.org/wiki/Root_mean_square_error>`_
@@ -480,6 +481,36 @@ Parameter for using AFT Survival Loss (``survival:aft``) and Negative Log Likeli
480481

481482
* ``aft_loss_distribution``: Probability Density Function, ``normal``, ``logistic``, or ``extreme``.
482483

484+
.. _ltr-param:
485+
486+
Parameters for learning to rank (``rank:ndcg``, ``rank:map``, ``rank:pairwise``)
487+
================================================================================
488+
489+
These are parameters specific to learning to rank task. See :doc:`Learning to Rank </tutorials/learning_to_rank>` for an in-depth explanation.
490+
491+
* ``lambdarank_pair_method`` [default = ``mean``]
492+
493+
How to construct pairs for pair-wise learning.
494+
495+
- ``mean``: Sample ``lambdarank_num_pair_per_sample`` pairs for each document in the query list.
496+
- ``topk``: Focus on top-``lambdarank_num_pair_per_sample`` documents. Construct :math:`|query|` pairs for each document at the top-``lambdarank_num_pair_per_sample`` ranked by the model.
497+
498+
* ``lambdarank_num_pair_per_sample`` [range = :math:`[1, \infty]`]
499+
500+
It specifies the number of pairs sampled for each document when pair method is ``mean``, or the truncation level for queries when the pair method is ``topk``. For example, to train with ``ndcg@6``, set ``lambdarank_num_pair_per_sample`` to :math:`6` and ``lambdarank_pair_method`` to ``topk``.
501+
502+
* ``lambdarank_unbiased`` [default = ``false``]
503+
504+
Specify whether do we need to debias input click data.
505+
506+
* ``lambdarank_bias_norm`` [default = 2.0]
507+
508+
:math:`L_p` normalization for position debiasing, default is :math:`L_2`. Only relevant when ``lambdarank_unbiased`` is set to true.
509+
510+
* ``ndcg_exp_gain`` [default = ``true``]
511+
512+
Whether we should use exponential gain function for ``NDCG``. There are two forms of gain function for ``NDCG``, one is using relevance value directly while the other is using :math:`2^{rel} - 1` to emphasize on retrieving relevant documents. When ``ndcg_exp_gain`` is true (the default), relevance degree cannot be greater than 31.
513+
483514
***********************
484515
Command Line Parameters
485516
***********************

doc/tutorials/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
2121
monotonic
2222
rf
2323
feature_interaction_constraint
24+
learning_to_rank
2425
aft_survival_analysis
2526
c_api_tutorial
2627
input_format

0 commit comments

Comments
 (0)