Skip to content

Commit aabe787

Browse files
committed
test check.
1 parent 290027c commit aabe787

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

python-package/xgboost/sklearn.py

+3-10
Original file line numberDiff line numberDiff line change
@@ -1967,7 +1967,7 @@ def fit(
19671967
provide qid.
19681968
qid :
19691969
Query ID for each training sample. Should have the size of n_samples. If
1970-
this is set to None, then user must provide group.
1970+
this is set to None, then user must provide group or a special column in X.
19711971
sample_weight :
19721972
Query group weights
19731973
@@ -1988,7 +1988,8 @@ def fit(
19881988
query groups in the ``i``-th pair in **eval_set**.
19891989
eval_qid :
19901990
A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
1991-
pair in **eval_set**.
1991+
pair in **eval_set**. The special column convention in `X` applies to
1992+
validation datasets as well.
19921993
19931994
eval_metric : str, list of str, optional
19941995
.. deprecated:: 1.6.0
@@ -2031,15 +2032,7 @@ def fit(
20312032
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
20322033
20332034
"""
2034-
# check if group information is provided
2035-
20362035
with config_context(verbosity=self.verbosity):
2037-
if eval_set is not None:
2038-
if eval_group is None and eval_qid is None:
2039-
raise ValueError(
2040-
"eval_group or eval_qid is required if eval_set is not None"
2041-
)
2042-
20432036
train_dmatrix, evals = _wrap_evaluation_matrices(
20442037
missing=self.missing,
20452038
X=X,

tests/python/test_with_sklearn.py

+8
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,11 @@ def test_ranking_qid_df():
217217
s = ranker.score(df, y)
218218
assert s > 0.7
219219

220+
# works with validation datasets as well
221+
valid_df = df.copy()
222+
valid_df.iloc[0, 0] = 3.0
223+
ranker.fit(df, y, eval_set=[(valid_df, y)])
224+
220225
# same as passing qid directly
221226
ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg")
222227
ranker.fit(X, y, qid=q)
@@ -238,6 +243,9 @@ def test_ranking_qid_df():
238243
results = cross_val_score(ranker, df, y)
239244
assert len(results) == 5
240245

246+
with pytest.raises(ValueError, match="Either `group` or `qid`."):
247+
ranker.fit(df, y, eval_set=[(X, y)])
248+
241249

242250
def test_stacking_regression():
243251
from sklearn.datasets import load_diabetes

0 commit comments

Comments
 (0)