
Commit 26856a8

[enc] Throw error when DMatrix is empty. (#11628)
1 parent 9311c65 commit 26856a8

13 files changed: +164 −101 lines

include/xgboost/collective/result.h

Lines changed: 2 additions & 1 deletion
@@ -159,5 +159,6 @@ template <typename Fn>
   return fn();
 }
 
-void SafeColl(Result const& rc);
+void SafeColl(Result const& rc, char const* file = __builtin_FILE(),
+              std::int32_t line = __builtin_LINE());
 }  // namespace xgboost::collective
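
Aside on the header change above: `__builtin_FILE()` and `__builtin_LINE()` are GCC/Clang builtins, and when used as default arguments they are evaluated at the call site, so every existing `SafeColl(rc)` call now records its own file and line without a macro wrapper. A minimal standalone sketch of the mechanism (illustrative names, not XGBoost code):

#include <cstdint>   // for int32_t
#include <iostream>

// Default arguments built from __builtin_FILE()/__builtin_LINE() are evaluated
// where the function is called, so the callee sees the caller's location.
void ReportHere(char const* msg, char const* file = __builtin_FILE(),
                std::int32_t line = __builtin_LINE()) {
  std::cout << file << ":" << line << ": " << msg << "\n";
}

int main() {
  ReportHere("first call");   // prints this file and this line number
  ReportHere("second call");  // prints the next line number
  return 0;
}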

python-package/xgboost/dask/__init__.py

Lines changed: 2 additions & 19 deletions
@@ -55,7 +55,7 @@
 import logging
 from collections import defaultdict
 from contextlib import contextmanager
-from functools import cache, partial, update_wrapper
+from functools import partial, update_wrapper
 from threading import Thread
 from typing import (
     Any,
@@ -85,8 +85,6 @@
 from dask import dataframe as dd
 from dask.delayed import Delayed
 from distributed import Future
-from packaging.version import Version
-from packaging.version import parse as parse_version
 
 from .. import collective, config
 from .._data_utils import Categories
@@ -124,7 +122,7 @@
 from ..tracker import RabitTracker
 from ..training import train as worker_train
 from .data import _get_dmatrices, no_group_split
-from .utils import get_address_from_user, get_n_threads
+from .utils import _DASK_2024_12_1, _DASK_2025_3_0, get_address_from_user, get_n_threads
 
 _DaskCollection: TypeAlias = Union[da.Array, dd.DataFrame, dd.Series]
 _DataT: TypeAlias = Union[da.Array, dd.DataFrame]  # do not use series as predictor
@@ -174,21 +172,6 @@
 LOGGER = logging.getLogger("[xgboost.dask]")
 
 
-@cache
-def _DASK_VERSION() -> Version:
-    return parse_version(dask.__version__)
-
-
-@cache
-def _DASK_2024_12_1() -> bool:
-    return _DASK_VERSION() >= parse_version("2024.12.1")
-
-
-@cache
-def _DASK_2025_3_0() -> bool:
-    return _DASK_VERSION() >= parse_version("2025.3.0")
-
-
 def _try_start_tracker(
     n_workers: int,
     addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],

python-package/xgboost/dask/utils.py

Lines changed: 20 additions & 0 deletions
@@ -1,10 +1,15 @@
+# pylint: disable=invalid-name
 """Utilities for the XGBoost Dask interface."""
 
 import logging
 import warnings
+from functools import cache as fcache
 from typing import Any, Dict, Optional, Tuple
 
+import dask
 import distributed
+from packaging.version import Version
+from packaging.version import parse as parse_version
 
 from ..collective import Config
 
@@ -97,3 +102,18 @@ def get_address_from_user(
         port = coll_cfg.tracker_port
 
     return host_ip, port
+
+
+@fcache
+def _DASK_VERSION() -> Version:
+    return parse_version(dask.__version__)
+
+
+@fcache
+def _DASK_2024_12_1() -> bool:
+    return _DASK_VERSION() >= parse_version("2024.12.1")
+
+
+@fcache
+def _DASK_2025_3_0() -> bool:
+    return _DASK_VERSION() >= parse_version("2025.3.0")

python-package/xgboost/testing/dask.py

Lines changed: 72 additions & 54 deletions
@@ -1,12 +1,13 @@
+# pylint: disable=invalid-name
 """Tests for dask shared by different test modules."""
 
-from typing import Any, List, Literal, Tuple, cast
+from typing import Any, List, Literal, Tuple, Type, cast
 
 import numpy as np
 import pandas as pd
 from dask import array as da
 from dask import dataframe as dd
-from distributed import Client, get_worker, wait
+from distributed import Client, get_worker
 from packaging.version import parse as parse_version
 from sklearn.datasets import make_classification
 
@@ -17,7 +18,8 @@
 
 from .. import dask as dxgb
 from .._typing import EvalsLog
-from ..dask import _DASK_VERSION, _get_rabit_args
+from ..dask import _get_rabit_args
+from ..dask.utils import _DASK_VERSION
 from .data import make_batches
 from .data import make_categorical as make_cat_local
 from .ordinal import make_recoded
@@ -325,61 +327,77 @@ def pack(**kwargs: Any) -> dd.DataFrame:
 # pylint: disable=too-many-locals
 def run_recode(client: Client, device: Device) -> None:
     """Run re-coding test with the Dask interface."""
-    enc, reenc, y, _, _ = make_recoded(device, n_features=96)
-    workers = get_client_workers(client)
-    denc, dreenc, dy = (
-        dd.from_pandas(enc, npartitions=8).persist(workers=workers),
-        dd.from_pandas(reenc, npartitions=8).persist(workers=workers),
-        da.from_array(y, chunks=(y.shape[0] // 8,)).persist(workers=workers),
-    )
 
-    wait([denc, dreenc, dy])
+    def create_dmatrix(
+        DMatrixT: Type[dxgb.DaskDMatrix], *args: Any, **kwargs: Any
+    ) -> dxgb.DaskDMatrix:
+        if DMatrixT is dxgb.DaskQuantileDMatrix:
+            ref = kwargs.pop("ref", None)
+            return DMatrixT(*args, ref=ref, **kwargs)
 
-    if device == "cuda":
-        denc = denc.to_backend("cudf")
-        dreenc = dreenc.to_backend("cudf")
-        dy = dy.to_backend("cupy")
+        kwargs.pop("ref", None)
+        return DMatrixT(*args, **kwargs)
 
-    Xy = dxgb.DaskQuantileDMatrix(client, denc, dy, enable_categorical=True)
-    Xy_valid = dxgb.DaskQuantileDMatrix(
-        client, dreenc, dy, enable_categorical=True, ref=Xy
-    )
-    # Base model
-    results = dxgb.train(client, {"device": device}, Xy, evals=[(Xy_valid, "Valid")])
+    def run(DMatrixT: Type[dxgb.DaskDMatrix]) -> None:
+        enc, reenc, y, _, _ = make_recoded(device, n_features=96)
+        to = get_client_workers(client)
 
-    # Training continuation
-    Xy = dxgb.DaskQuantileDMatrix(client, denc, dy, enable_categorical=True)
-    Xy_valid = dxgb.DaskQuantileDMatrix(
-        client, dreenc, dy, enable_categorical=True, ref=Xy
-    )
-    results_1 = dxgb.train(
-        client,
-        {"device": device},
-        Xy,
-        evals=[(Xy_valid, "Valid")],
-        xgb_model=results["booster"],
-    )
+        denc, dreenc, dy = (
+            dd.from_pandas(enc, npartitions=8).persist(workers=to),
+            dd.from_pandas(reenc, npartitions=8).persist(workers=to),
+            da.from_array(y, chunks=(y.shape[0] // 8,)).persist(workers=to),
+        )
 
-    # Reversed training continuation
-    Xy = dxgb.DaskQuantileDMatrix(client, dreenc, dy, enable_categorical=True)
-    Xy_valid = dxgb.DaskQuantileDMatrix(
-        client, denc, dy, enable_categorical=True, ref=Xy
-    )
-    results_2 = dxgb.train(
-        client,
-        {"device": device},
-        Xy,
-        evals=[(Xy_valid, "Valid")],
-        xgb_model=results["booster"],
-    )
-    np.testing.assert_allclose(
-        results_1["history"]["Valid"]["rmse"], results_2["history"]["Valid"]["rmse"]
-    )
+        if device == "cuda":
+            denc = denc.to_backend("cudf")
+            dreenc = dreenc.to_backend("cudf")
+            dy = dy.to_backend("cupy")
+
+        Xy = create_dmatrix(DMatrixT, client, denc, dy, enable_categorical=True)
+        Xy_valid = create_dmatrix(
+            DMatrixT, client, dreenc, dy, enable_categorical=True, ref=Xy
+        )
+        # Base model
+        results = dxgb.train(
+            client, {"device": device}, Xy, evals=[(Xy_valid, "Valid")]
+        )
+
+        # Training continuation
+        Xy = create_dmatrix(DMatrixT, client, denc, dy, enable_categorical=True)
+        Xy_valid = create_dmatrix(
+            DMatrixT, client, dreenc, dy, enable_categorical=True, ref=Xy
+        )
+        results_1 = dxgb.train(
+            client,
+            {"device": device},
+            Xy,
+            evals=[(Xy_valid, "Valid")],
+            xgb_model=results["booster"],
+        )
+
+        # Reversed training continuation
+        Xy = create_dmatrix(DMatrixT, client, dreenc, dy, enable_categorical=True)
+        Xy_valid = create_dmatrix(
+            DMatrixT, client, denc, dy, enable_categorical=True, ref=Xy
+        )
+        results_2 = dxgb.train(
+            client,
+            {"device": device},
+            Xy,
+            evals=[(Xy_valid, "Valid")],
+            xgb_model=results["booster"],
+        )
+        np.testing.assert_allclose(
+            results_1["history"]["Valid"]["rmse"], results_2["history"]["Valid"]["rmse"]
+        )
+
+        predt_0 = dxgb.inplace_predict(client, results, denc).compute()
+        predt_1 = dxgb.inplace_predict(client, results, dreenc).compute()
+        assert_allclose(device, predt_0, predt_1)
 
-    predt_0 = dxgb.inplace_predict(client, results, denc).compute()
-    predt_1 = dxgb.inplace_predict(client, results, dreenc).compute()
-    assert_allclose(device, predt_0, predt_1)
+        predt_0 = dxgb.predict(client, results, Xy).compute()
+        predt_1 = dxgb.predict(client, results, Xy_valid).compute()
+        assert_allclose(device, predt_0, predt_1)
 
-    predt_0 = dxgb.predict(client, results, Xy).compute()
-    predt_1 = dxgb.predict(client, results, Xy_valid).compute()
-    assert_allclose(device, predt_0, predt_1)
+    for DMatrixT in [dxgb.DaskDMatrix, dxgb.DaskQuantileDMatrix]:
+        run(DMatrixT)

src/collective/result.cc

Lines changed: 15 additions & 5 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2024, XGBoost Contributors
+ * Copyright 2024-2025, XGBoost Contributors
  */
 #include "xgboost/collective/result.h"
 
@@ -65,17 +65,27 @@ void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
 std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
   dmlc::DateLogger logger;
   if (file && line != -1) {
-    auto name = std::filesystem::path{ file }.filename();
+    auto name = std::filesystem::path{file}.filename();
     return "[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() +
            "]: " + std::forward<std::string>(msg);
   }
   return std::string{"["} + logger.HumanDate() + "]" + std::forward<std::string>(msg);  // NOLINT
 }
 }  // namespace detail
 
-void SafeColl(Result const& rc) {
-  if (!rc.OK()) {
-    LOG(FATAL) << rc.Report();
+void SafeColl(Result const& rc, char const* file, std::int32_t line) {
+  if (rc.OK()) {
+    return;
   }
+  if (file && line != -1) {
+    dmlc::DateLogger logger;
+    auto name = std::filesystem::path{file}.filename();
+    LOG(FATAL) << ("[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() +
+                   "]:\n")
+               << rc.Report();
+    // Return just in case if this function is deep in ctypes callbacks.
+    return;
+  }
+  LOG(FATAL) << rc.Report();
 }
 }  // namespace xgboost::collective

src/data/cat_container.cc

Lines changed: 23 additions & 3 deletions
@@ -9,9 +9,11 @@
 #include <utility>  // for move
 #include <vector>   // for vector
 
-#include "../common/error_msg.h"  // for NoFloatCat
-#include "../encoder/types.h"     // for Overloaded
-#include "xgboost/json.h"         // for Json
+#include "../collective/allreduce.h"         // for Allreduce
+#include "../collective/communicator-inl.h"  // for GetRank, GetWorldSize
+#include "../common/error_msg.h"             // for NoFloatCat
+#include "../encoder/types.h"                // for Overloaded
+#include "xgboost/json.h"                    // for Json
 
 namespace xgboost {
 CatContainer::CatContainer(enc::HostColumnsView const& df, bool is_ref) : CatContainer{} {
@@ -293,4 +295,22 @@ void CatContainer::Sort(Context const* ctx) {
   enc::SortNames(enc::Policy<EncErrorPolicy>{}, view, this->sorted_idx_.HostSpan());
 }
 #endif  // !defined(XGBOOST_USE_CUDA)
+
+void SyncCategories(Context const* ctx, CatContainer* cats, bool is_empty) {
+  CHECK(cats);
+  if (!collective::IsDistributed()) {
+    return;
+  }
+
+  auto rank = collective::GetRank();
+  std::vector<std::int32_t> workers(collective::GetWorldSize(), 0);
+  workers[rank] = is_empty;
+  collective::SafeColl(collective::Allreduce(ctx, &workers, collective::Op::kSum));
+  if (cats->HasCategorical() &&
+      std::any_of(workers.cbegin(), workers.cend(), [](auto v) { return v == 1; })) {
+    LOG(FATAL)
+        << "A worker cannot have empty input when a dataframe with categorical features is used. "
+           "XGBoost cannot infer the categories if the input is empty.";
+  }
+}
 }  // namespace xgboost
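
The SyncCategories routine added above follows a simple pattern: each worker writes a 0/1 "I am empty" flag into its own slot of a world-size vector, the vectors are summed with an allreduce, and every worker can then check whether any peer received an empty DMatrix before categorical re-coding, failing loudly if so. A self-contained sketch of that flag-vector check, with the collective replaced by a local stand-in (AllreduceSum and has_categorical are illustrative, not the XGBoost collective API):

#include <algorithm>  // for any_of
#include <cstddef>    // for size_t
#include <cstdint>    // for int32_t
#include <iostream>
#include <vector>

// Local stand-in for a sum-allreduce: element-wise sum of every worker's flag
// vector, as if each rank contributed its own copy to the collective.
std::vector<std::int32_t> AllreduceSum(std::vector<std::vector<std::int32_t>> const& per_worker) {
  std::vector<std::int32_t> out(per_worker.front().size(), 0);
  for (auto const& w : per_worker) {
    for (std::size_t i = 0; i < w.size(); ++i) {
      out[i] += w[i];
    }
  }
  return out;
}

int main() {
  std::int32_t world_size = 4;
  // Each worker sets workers[rank] = is_empty before the reduction.
  std::vector<std::vector<std::int32_t>> per_worker(
      world_size, std::vector<std::int32_t>(world_size, 0));
  per_worker[2][2] = 1;  // pretend rank 2 received an empty DMatrix

  auto workers = AllreduceSum(per_worker);
  bool has_categorical = true;  // stands in for cats->HasCategorical()
  if (has_categorical &&
      std::any_of(workers.cbegin(), workers.cend(), [](auto v) { return v == 1; })) {
    std::cerr << "A worker cannot have empty input when categorical features are used.\n";
    return 1;
  }
  return 0;
}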

src/data/cat_container.h

Lines changed: 3 additions & 1 deletion
@@ -162,7 +162,7 @@ class CatContainer {
    * this method returns True.
    */
   [[nodiscard]] bool Empty() const;
-  [[nodiscard]] bool NeedRecode() const { return !this->Empty() && !this->is_ref_; }
+  [[nodiscard]] bool NeedRecode() const { return this->HasCategorical() && !this->is_ref_; }
 
   [[nodiscard]] std::size_t NumFeatures() const;
   /**
@@ -263,6 +263,8 @@ struct NoOpAccessor {
   [[nodiscard]] XGBOOST_DEVICE float operator()(Entry const& e) const { return e.fvalue; }
 };
 
+void SyncCategories(Context const* ctx, CatContainer* cats, bool is_empty);
+
 namespace cpu_impl {
 inline auto MakeCatAccessor(Context const* ctx, enc::HostColumnsView const& new_enc,
                             CatContainer const* orig_cats) {

src/data/extmem_quantile_dmatrix.cc

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,8 @@ ExtMemQuantileDMatrix::ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrix
   }
   this->batch_ = p;
   this->fmat_ctx_ = ctx;
+
+  SyncCategories(&ctx, info_.Cats(), info_.num_row_ == 0);
 }
 
 ExtMemQuantileDMatrix::~ExtMemQuantileDMatrix() {

src/data/iterative_dmatrix.cc

Lines changed: 4 additions & 1 deletion
@@ -9,10 +9,11 @@
 #include <utility>  // for move
 #include <vector>   // for vector
 
-#include "../common/categorical.h"  // common::IsCat
+#include "../common/categorical.h"  // for IsCat
 #include "../common/hist_util.h"    // for HistogramCuts
 #include "../tree/param.h"   // FIXME(jiamingy): Find a better way to share this parameter.
 #include "batch_utils.h"     // for RegenGHist
+#include "cat_container.h"   // for SyncCategories
 #include "gradient_index.h"  // for GHistIndexMatrix
 #include "proxy_dmatrix.h"   // for DataIterProxy, DispatchAny
 #include "quantile_dmatrix.h"  // for GetCutsFromRef
@@ -50,6 +51,8 @@ IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle pro
   this->fmat_ctx_ = ctx;
   this->batch_ = p;
 
+  SyncCategories(&ctx, info_.Cats(), info_.num_row_ == 0);
+
   LOG(INFO) << "Finished constructing the `IterativeDMatrix`: (" << this->Info().num_row_ << ", "
             << this->Info().num_col_ << ", " << this->info_.num_nonzero_ << ").";
 }

src/data/simple_dmatrix.cc

Lines changed: 2 additions & 0 deletions
@@ -341,6 +341,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
   }
   info_.num_nonzero_ = data_vec.size();
 
+  SyncCategories(&ctx, info_.Cats(), info_.num_row_ == 0);
+
   // Sort the index for row partitioners used by variuos tree methods.
   if (!sparse_page_->IsIndicesSorted(ctx.Threads())) {
     sparse_page_->SortIndices(ctx.Threads());
