
Commit f25f74d

[enc][dask] Support training continuation. (#11609)
1 parent aea640d commit f25f74d

40 files changed: +612 −325 lines

demo/guide-python/cat_pipeline.py

Lines changed: 5 additions & 0 deletions
@@ -6,6 +6,11 @@
 training and inference. There are many ways to attain the same goal, this script can be
 used as a starting point.
 
+.. versionchanged:: 3.1
+
+    Starting with 3.1, users don't need this in most cases. See :ref:`cat-recode` for
+    more info.
+
 See Also
 --------
 - :doc:`Tutorial </tutorials/categorical>`
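For context, the demo this note refers to is built around fitting an ordinal encoder once and reusing it for all later data. A minimal sketch of that idea (hypothetical column and data, not the demo's actual code):

    import pandas as pd
    from sklearn.preprocessing import OrdinalEncoder

    # Fit the encoder on the training data only.
    enc = OrdinalEncoder()
    X_train = pd.DataFrame({"genre": ["acoustic", "indie", "blues", "country"]})
    X_train_enc = pd.DataFrame(
        enc.fit_transform(X_train), columns=X_train.columns
    ).astype("category")

    # Reuse the *same* fitted encoder on test data to keep the encoding consistent.
    X_test = pd.DataFrame({"genre": ["blues", "acoustic"]})
    X_test_enc = pd.DataFrame(
        enc.transform(X_test), columns=X_test.columns
    ).astype("category")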

doc/python/python_api.rst

Lines changed: 8 additions & 0 deletions
@@ -206,6 +206,14 @@ Collective
 
 .. autofunction:: xgboost.collective.init
 
+.. autofunction:: xgboost.collective.finalize
+
+.. autofunction:: xgboost.collective.get_rank
+
+.. autofunction:: xgboost.collective.get_world_size
+
+.. autoclass:: xgboost.collective.CommunicatorContext
+
 .. automodule:: xgboost.tracker
 
 .. autoclass:: xgboost.tracker.RabitTracker
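The newly documented helpers compose roughly as follows; a sketch assuming a recent XGBoost and a local single-worker tracker (illustrative only, not part of this commit):

    import xgboost as xgb
    from xgboost.tracker import RabitTracker

    # Start a tracker for a one-worker communication group.
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
    tracker.start()

    # CommunicatorContext calls collective.init() on entry and
    # collective.finalize() on exit.
    with xgb.collective.CommunicatorContext(**tracker.worker_args()):
        print(xgb.collective.get_rank(), xgb.collective.get_world_size())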

doc/tutorials/categorical.rst

Lines changed: 87 additions & 16 deletions
@@ -137,38 +137,109 @@ feature it's specified as ``"c"``. The Dask module in XGBoost has the same inte
 :class:`dask.Array <dask.Array>` can also be used for categorical data. Lastly, the
 sklearn interface :py:class:`~xgboost.XGBRegressor` has the same parameter.
 
-****************
-Data Consistency
-****************
+.. _cat-recode:
 
-XGBoost accepts parameters to indicate which feature is considered categorical, either through the ``dtypes`` of a dataframe or through the ``feature_types`` parameter. However, XGBoost by itself doesn't store information on how categories are encoded in the first place. For instance, given an encoding schema that maps music genres to integer codes:
+********************************
+Auto-recoding (Data Consistency)
+********************************
+
+.. versionchanged:: 3.1
+
+    Starting with XGBoost 3.1, the *Python* interface can perform automatic re-coding
+    for new inputs.
+
+XGBoost accepts parameters to indicate which feature is considered categorical, either
+through the ``dtypes`` of a dataframe or through the ``feature_types`` parameter.
+However, except for the Python interface, XGBoost doesn't store information about how
+categories are encoded in the first place. For instance, given an encoding schema that
+maps music genres to integer codes:
 
 .. code-block:: python
 
     {"acoustic": 0, "indie": 1, "blues": 2, "country": 3}
 
-XGBoost doesn't know this mapping from the input and hence cannot store it in the model. The mapping usually happens in the users' data engineering pipeline with column transformers like :py:class:`sklearn.preprocessing.OrdinalEncoder`. To make sure correct result from XGBoost, users need to keep the pipeline for transforming data consistent across training and testing data. One should watch out for errors like:
+Aside from the Python interface (R/Java/C, etc.), XGBoost doesn't know this mapping from
+the input and hence cannot store it in the model. The mapping usually happens in the
+users' data engineering pipeline. To ensure correct results from XGBoost, users need to
+keep the pipeline for transforming data consistent across training and testing data.
+
+Starting with 3.1, the *Python* interface can remember the encoding and perform
+re-coding during inference and training continuation when the input is a dataframe
+(`pandas`, `cuDF`, `polars`, `pyarrow`, `modin`). The feature focuses on basic usage and
+has some restrictions on the types of inputs that can be accepted. First, category names
+must have one of the following types:
+
+- string
+- integer, from 8-bit to 64-bit; both signed and unsigned are supported.
+- 32-bit or 64-bit floating point
+
+Other category types are not supported. Second, the input types must be strictly
+consistent. For example, XGBoost will raise an error if the categorical columns in the
+training set are unsigned integers whereas the test dataset has signed integer columns.
+If you have categories that are not one of the supported types, you need to perform the
+re-coding using a pre-processing data transformer like
+:py:class:`sklearn.preprocessing.OrdinalEncoder`. See
+:ref:`sphx_glr_python_examples_cat_pipeline.py` for a worked example using an ordinal
+encoder. To clarify, the type here refers to the type of the names of categories (called
+``Index`` in pandas):
+
+.. code-block:: python
+
+    # string type
+    {"acoustic": 0, "indie": 1, "blues": 2, "country": 3}
+    # integer type
+    {-1: 0, 1: 1, 3: 2, 7: 3}
+    # depending on the dataframe implementation, it can be signed or unsigned.
+    {5: 0, 1: 1, 3: 2, 7: 3}
+    # floating point type, both 32-bit and 64-bit are supported.
+    {-1.0: 0, 1.0: 1, 3.0: 2, 7.0: 3}
+
+Internally, XGBoost attempts to extract the categories from the dataframe inputs. For
+inference (predict), the re-coding happens on the fly and there's no data copy (barring
+some internal transformations performed by the dataframe itself). For training
+continuation however, re-coding requires some extra steps if you are using the native
+interface. The sklearn interface and the Dask interface can handle training continuation
+automatically. Lastly, please note that using the re-coder with the native interface is
+still experimental. It's ready for testing, but we want to observe the feature's usage
+for a period of time and might make some breaking changes if needed. The following is a
+snippet of using the native interface:
 
 .. code-block:: python
 
-    X_train["genre"] = X_train["genre"].astype("category")
-    reg = xgb.XGBRegressor(enable_categorical=True).fit(X_train, y_train)
+    import pandas as pd
+
+    import xgboost
+
+    # `X`, `y`, `X_new`, and `y_new` are placeholders for real dataframe inputs.
+    X = pd.DataFrame()
+    Xy = xgboost.QuantileDMatrix(X, y, enable_categorical=True)
+    booster = xgboost.train({}, Xy)
+
+    # XGBoost can handle re-coding for inference without user intervention.
+    X_new = pd.DataFrame()
+    booster.inplace_predict(X_new)
+
+    # Get the categories saved in the model for training continuation.
+    categories = booster.get_categories()
+    # Use the saved categories as a reference for re-coding. Training continuation
+    # requires a re-coded DMatrix; pass the categories as `feature_types`.
+    Xy_new = xgboost.QuantileDMatrix(
+        X_new, y_new, feature_types=categories, enable_categorical=True, ref=Xy
+    )
+    booster_1 = xgboost.train({}, Xy_new, xgb_model=booster)
 
-    # invalid encoding
-    X_test["genre"] = X_test["genre"].astype("category")
-    reg.predict(X_test)
 
-In the above snippet, training data and test data are encoded separately, resulting in two different encoding schemas and invalid prediction result. See :ref:`sphx_glr_python_examples_cat_pipeline.py` for a worked example using ordinal encoder.
+No extra step is required for using the scikit-learn interface as long as the inputs are
+dataframes. During training continuation, XGBoost will either extract the categories
+from the previous model or use the categories from the new training dataset if the input
+model doesn't have the information.
 
 *************
 Miscellaneous
 *************
 
-By default, XGBoost assumes input categories are integers starting from 0 till the number
-of categories :math:`[0, n\_categories)`. However, user might provide inputs with invalid
-values due to mistakes or missing values in training dataset. It can be negative value,
-integer values that can not be accurately represented by 32-bit floating point, or values
-that are larger than actual number of unique categories. During training this is
+By default, XGBoost assumes input category codes are integers starting from 0 up to the
+number of categories :math:`[0, n\_categories)`. However, users might provide inputs
+with invalid values due to mistakes or missing values in the training dataset. These can
+be negative values, integer values that cannot be accurately represented by 32-bit
+floating point, or values larger than the actual number of unique categories. During
+training this is
 validated, but for prediction it's treated the same as a not-chosen category for
 performance reasons.

python-package/xgboost/_data_utils.py

Lines changed: 5 additions & 1 deletion
@@ -683,7 +683,11 @@ def __del__(self) -> None:
 def get_ref_categories(
     feature_types: Optional[Union[FeatureTypes, Categories]],
 ) -> Tuple[Optional[FeatureTypes], Optional[Categories]]:
-    """Get the optional reference categories from the input."""
+    """Get the optional reference categories from the `feature_types`. This is used by
+    the various `DMatrix` types, where `feature_types` is reused for specifying the
+    reference categories.
+
+    """
     if isinstance(feature_types, Categories):
         ref_categories = feature_types
         feature_types = None
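The dual role of `feature_types` that this docstring describes can be sketched from the caller's side. This is a hedged illustration of the helper's apparent contract (a private module, internal API subject to change; the pass-through return for plain types is inferred from the snippet above, not shown in the diff):

    from xgboost._data_utils import get_ref_categories

    # Plain per-feature type strings pass through; no reference categories.
    ft, ref = get_ref_categories(["q", "c"])
    assert ft == ["q", "c"] and ref is None

    # A `Categories` object (e.g. from `booster.get_categories()`) would instead
    # be split out as the reference:
    #     ft, ref = get_ref_categories(categories)  # -> (None, categories)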

python-package/xgboost/collective.py

Lines changed: 9 additions & 7 deletions
@@ -37,7 +37,8 @@ class Config:
         See `dmlc_timeout` in :py:meth:`init`. This is only used for communicators, not
         the tracker. They are different parameters since the timeout for tracker limits
         only the time for starting and finalizing the communication group, whereas the
-        timeout for communicators limits the time used for collective operations.
+        timeout for communicators limits the time used for collective operations, like
+        :py:meth:`allreduce`.
 
     tracker_host_ip : See :py:class:`~xgboost.tracker.RabitTracker`.
 
@@ -94,7 +95,8 @@ def init(**args: _ArgVals) -> None:
     - federated_client_cert: Client certificate file path. Only needed for the SSL
       mode.
 
-    Use upper case for environment variables, use lower case for runtime configuration.
+    Use upper case for environment variables, use lower case for runtime
+    configuration.
 
     """
     _check_call(_LIB.XGCommunicatorInit(make_jcargs(**args)))
@@ -122,17 +124,17 @@ def get_world_size() -> int:
 
     Returns
     -------
-    n : int
+    n :
         Total number of processes.
     """
     ret = _LIB.XGCommunicatorGetWorldSize()
     return ret
 
 
-def is_distributed() -> int:
+def is_distributed() -> bool:
     """If the collective communicator is distributed."""
     is_dist = _LIB.XGCommunicatorIsDistributed()
-    return is_dist
+    return bool(is_dist)
 
 
 def communicator_print(msg: Any) -> None:
@@ -160,8 +162,8 @@ def get_processor_name() -> str:
 
     Returns
     -------
-    name : str
-        the name of processor(host)
+    name :
+        The name of the processor (host).
     """
     name_str = ctypes.c_char_p()
     _check_call(_LIB.XGCommunicatorGetProcessorName(ctypes.byref(name_str)))
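A small sanity-check sketch of the helpers touched above, run outside of any communication group. The single-process fallback values are an assumption based on the docs, not verified against this commit:

    import xgboost.collective as coll

    print(coll.is_distributed())      # now typed (and coerced) to bool; False here
    print(coll.get_world_size())      # expected 1 in a single-process setting
    print(coll.get_rank())            # expected 0
    print(coll.get_processor_name())  # the host name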

python-package/xgboost/core.py

Lines changed: 2 additions & 2 deletions
@@ -1361,12 +1361,12 @@ def get_categories(self, export_to_arrow: bool = False) -> Categories:
 
         .. warning::
 
-            This function is still working in progress.
+            This function is experimental.
 
         Parameters
         ----------
         export_to_arrow :
-            The returned container will contain a list to ``pyarrow`` arrays for the
+            The returned container will contain a list of ``pyarrow`` arrays for the
             categories. See the :py:meth:`~Categories.to_arrow` for more info.
 
         """

python-package/xgboost/dask/__init__.py

Lines changed: 11 additions & 41 deletions
@@ -89,6 +89,7 @@
 from packaging.version import parse as parse_version
 
 from .. import collective, config
+from .._data_utils import Categories
 from .._typing import FeatureNames, FeatureTypes, IterationRange
 from ..callback import TrainingCallback
 from ..collective import Config as CollConfig
@@ -122,7 +123,7 @@
 )
 from ..tracker import RabitTracker
 from ..training import train as worker_train
-from .data import _create_dmatrix, _create_quantile_dmatrix, no_group_split
+from .data import _get_dmatrices, no_group_split
 from .utils import get_address_from_user, get_n_threads
 
 _DaskCollection: TypeAlias = Union[da.Array, dd.DataFrame, dd.Series]
@@ -331,6 +332,10 @@ def __init__(
 
         self.feature_names = feature_names
         self.feature_types = feature_types
+        if isinstance(feature_types, Categories):
+            raise TypeError(
+                "The Dask interface can handle categories from DataFrame automatically."
+            )
         self.missing = missing if missing is not None else numpy.nan
         self.enable_categorical = enable_categorical
@@ -652,12 +657,6 @@ def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
         return args
 
 
-def _dmatrix_from_list_of_parts(is_quantile: bool, **kwargs: Any) -> DMatrix:
-    if is_quantile:
-        return _create_quantile_dmatrix(**kwargs)
-    return _create_dmatrix(**kwargs)
-
-
 async def _get_rabit_args(
     client: "distributed.Client",
     n_workers: int,
@@ -735,37 +734,6 @@ async def _check_workers_are_alive(
         raise RuntimeError(f"Missing required workers: {missing_workers}")
 
 
-def _get_dmatrices(
-    train_ref: dict,
-    train_id: int,
-    *refs: dict,
-    evals_id: Sequence[int],
-    evals_name: Sequence[str],
-    n_threads: int,
-) -> Tuple[DMatrix, List[Tuple[DMatrix, str]]]:
-    # Create training DMatrix
-    Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
-    # Create evaluation DMatrices
-    evals: List[Tuple[DMatrix, str]] = []
-    for i, ref in enumerate(refs):
-        # Same DMatrix as the training
-        if evals_id[i] == train_id:
-            evals.append((Xy, evals_name[i]))
-            continue
-        if ref.get("ref", None) is not None:
-            if ref["ref"] != train_id:
-                raise ValueError(
-                    "The training DMatrix should be used as a reference to evaluation"
-                    " `QuantileDMatrix`."
-                )
-            del ref["ref"]
-            eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads, ref=Xy)
-        else:
-            eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
-        evals.append((eval_Xy, evals_name[i]))
-    return Xy, evals
-
-
 async def _train_async(
     *,
     client: "distributed.Client",
@@ -817,6 +785,8 @@ def do_train(  # pylint: disable=too-many-positional-arguments
             evals_id=evals_id,
             evals_name=evals_name,
             n_threads=n_threads,
+            # We need the model for reference categories.
+            model=xgb_model,
         )
 
         booster = worker_train(
@@ -1934,7 +1904,7 @@ class DaskXGBRanker(XGBRankerMixIn, DaskScikitLearnBase):
     def __init__(
         self,
         *,
-        objective: str = "rank:pairwise",
+        objective: str = "rank:ndcg",
         allow_group_split: bool = False,
         coll_cfg: Optional[CollConfig] = None,
         **kwargs: Any,
@@ -2051,8 +2021,8 @@ def check_ser(
     ) -> TypeGuard[Optional[dd.Series]]:
         if not isinstance(qid, dd.Series) and qid is not None:
             raise TypeError(
-                f"When `allow_group_split` is set to False, {name} is required to be"
-                " a series."
+                f"When `allow_group_split` is set to False, {name} is required to "
+                "be a series."
             )
         return True
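Putting the commit's headline together, here is a hedged end-to-end sketch of Dask training continuation (local cluster, synthetic data; parameter names follow the public `xgboost.dask` API):

    import numpy as np
    import pandas as pd
    import xgboost as xgb
    from dask import dataframe as dd
    from distributed import Client, LocalCluster

    pdf_X = pd.DataFrame(
        {"genre": pd.Categorical(np.random.choice(["a", "b", "c"], size=256))}
    )
    pdf_y = pd.Series(np.random.normal(size=256))

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = dd.from_pandas(pdf_X, npartitions=2)
        y = dd.from_pandas(pdf_y, npartitions=2)
        Xy = xgb.dask.DaskQuantileDMatrix(client, X, y, enable_categorical=True)
        out = xgb.dask.train(client, {"tree_method": "hist"}, Xy, num_boost_round=4)
        # Training continuation: reference categories come from the previous model,
        # which is what this commit enables.
        out = xgb.dask.train(
            client, {"tree_method": "hist"}, Xy,
            num_boost_round=4, xgb_model=out["booster"],
        )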
