File tree 2 files changed +11
-10
lines changed
2 files changed +11
-10
lines changed Original file line number Diff line number Diff line change @@ -1967,7 +1967,7 @@ def fit(
1967
1967
provide qid.
1968
1968
qid :
1969
1969
Query ID for each training sample. Should have the size of n_samples. If
1970
- this is set to None, then user must provide group.
1970
+ this is set to None, then user must provide group or a special column in X .
1971
1971
sample_weight :
1972
1972
Query group weights
1973
1973
@@ -1988,7 +1988,8 @@ def fit(
1988
1988
query groups in the ``i``-th pair in **eval_set**.
1989
1989
eval_qid :
1990
1990
A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
1991
- pair in **eval_set**.
1991
+ pair in **eval_set**. The special column convention in `X` applies to
1992
+ validation datasets as well.
1992
1993
1993
1994
eval_metric : str, list of str, optional
1994
1995
.. deprecated:: 1.6.0
@@ -2031,15 +2032,7 @@ def fit(
2031
2032
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
2032
2033
2033
2034
"""
2034
- # check if group information is provided
2035
-
2036
2035
with config_context (verbosity = self .verbosity ):
2037
- if eval_set is not None :
2038
- if eval_group is None and eval_qid is None :
2039
- raise ValueError (
2040
- "eval_group or eval_qid is required if eval_set is not None"
2041
- )
2042
-
2043
2036
train_dmatrix , evals = _wrap_evaluation_matrices (
2044
2037
missing = self .missing ,
2045
2038
X = X ,
Original file line number Diff line number Diff line change @@ -217,6 +217,11 @@ def test_ranking_qid_df():
217
217
s = ranker .score (df , y )
218
218
assert s > 0.7
219
219
220
+ # works with validation datasets as well
221
+ valid_df = df .copy ()
222
+ valid_df .iloc [0 , 0 ] = 3.0
223
+ ranker .fit (df , y , eval_set = [(valid_df , y )])
224
+
220
225
# same as passing qid directly
221
226
ranker = xgb .XGBRanker (n_estimators = 3 , eval_metric = "ndcg" )
222
227
ranker .fit (X , y , qid = q )
@@ -238,6 +243,9 @@ def test_ranking_qid_df():
238
243
results = cross_val_score (ranker , df , y )
239
244
assert len (results ) == 5
240
245
246
+ with pytest .raises (ValueError , match = "Either `group` or `qid`." ):
247
+ ranker .fit (df , y , eval_set = [(X , y )])
248
+
241
249
242
250
def test_stacking_regression ():
243
251
from sklearn .datasets import load_diabetes
You can’t perform that action at this time.
0 commit comments