Skip to content

Commit 9311c65

Browse files
authored
Support leaf prediction with QDM on CPU. (#11620)
1 parent 0d47374 commit 9311c65

File tree

11 files changed

+459
-348
lines changed

11 files changed

+459
-348
lines changed

python-package/xgboost/testing/dask.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,16 +322,18 @@ def pack(**kwargs: Any) -> dd.DataFrame:
322322
return X, y
323323

324324

325+
# pylint: disable=too-many-locals
325326
def run_recode(client: Client, device: Device) -> None:
326327
"""Run re-coding test with the Dask interface."""
327328
enc, reenc, y, _, _ = make_recoded(device, n_features=96)
329+
workers = get_client_workers(client)
328330
denc, dreenc, dy = (
329-
dd.from_pandas(enc, npartitions=8),
330-
dd.from_pandas(reenc, npartitions=8),
331-
da.from_array(y, chunks=(y.shape[0] // 8,)),
331+
dd.from_pandas(enc, npartitions=8).persist(workers=workers),
332+
dd.from_pandas(reenc, npartitions=8).persist(workers=workers),
333+
da.from_array(y, chunks=(y.shape[0] // 8,)).persist(workers=workers),
332334
)
335+
333336
wait([denc, dreenc, dy])
334-
client.rebalance([denc, dreenc, dy])
335337

336338
if device == "cuda":
337339
denc = denc.to_backend("cudf")
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Tests for inference."""
2+
3+
from typing import Type
4+
5+
import numpy as np
6+
7+
from ..core import DMatrix
8+
from ..training import train
9+
from .shared import validate_leaf_output
10+
from .utils import Device
11+
12+
13+
# pylint: disable=invalid-name,too-many-locals
14+
def run_predict_leaf(device: Device, DMatrixT: Type[DMatrix]) -> np.ndarray:
15+
"""Run tests for leaf index prediction."""
16+
rows = 100
17+
cols = 4
18+
classes = 5
19+
num_parallel_tree = 4
20+
num_boost_round = 10
21+
rng = np.random.RandomState(1994)
22+
X = rng.randn(rows, cols)
23+
y = rng.randint(low=0, high=classes, size=rows)
24+
25+
m = DMatrixT(X, y)
26+
booster = train(
27+
{
28+
"num_parallel_tree": num_parallel_tree,
29+
"num_class": classes,
30+
"tree_method": "hist",
31+
},
32+
m,
33+
num_boost_round=num_boost_round,
34+
)
35+
36+
booster.set_param({"device": device})
37+
empty = DMatrixT(np.ones(shape=(0, cols)))
38+
empty_leaf = booster.predict(empty, pred_leaf=True)
39+
assert empty_leaf.shape[0] == 0
40+
41+
leaf = booster.predict(m, pred_leaf=True, strict_shape=True)
42+
assert leaf.shape[0] == rows
43+
assert leaf.shape[1] == num_boost_round
44+
assert leaf.shape[2] == classes
45+
assert leaf.shape[3] == num_parallel_tree
46+
47+
validate_leaf_output(leaf, num_parallel_tree)
48+
49+
n_iters = np.int32(2)
50+
sliced = booster.predict(
51+
m,
52+
pred_leaf=True,
53+
iteration_range=(0, n_iters),
54+
strict_shape=True,
55+
)
56+
first = sliced[0, ...]
57+
58+
assert np.prod(first.shape) == classes * num_parallel_tree * n_iters
59+
60+
# When there's only 1 tree, the output is a 1 dim vector
61+
booster = train({"tree_method": "hist"}, num_boost_round=1, dtrain=m)
62+
booster.set_param({"device": device})
63+
assert booster.predict(m, pred_leaf=True).shape == (rows,)
64+
65+
return leaf

src/common/threading_utils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,17 @@ void ParallelFor(Index size, std::int32_t n_threads, Func&& fn) {
251251
ParallelFor(size, n_threads, Sched::Static(), std::forward<Func>(fn));
252252
}
253253

254+
/**
255+
* @brief 1-d block-based parallel for loop.
256+
*
257+
* @tparam kBlockOfRowsSize The size of the block.
258+
* @tparam Index The type of the index.
259+
* @tparam Func The type of the function.
260+
*
261+
* @param size The size of the range.
262+
* @param n_threads The number of threads.
263+
* @param fn The function to execute. The function should take a Range1d as an argument.
264+
*/
254265
template <std::size_t kBlockOfRowsSize, typename Index, typename Func>
255266
void ParallelFor1d(Index size, std::int32_t n_threads, Func&& fn) {
256267
static_assert(std::is_void_v<std::invoke_result_t<Func, common::Range1d>>);

src/data/ellpack_page.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,10 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
462462
info{CalcNumSymbols(
463463
ctx,
464464
[&] {
465+
if (page.Size() == 0) {
466+
return static_cast<typename decltype(page.row_ptr)::value_type>(0);
467+
}
468+
CHECK_GE(page.row_ptr.size(), 2);
465469
auto it = common::MakeIndexTransformIter(
466470
[&](bst_idx_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; });
467471
return *std::max_element(it, it + page.Size());

src/data/gradient_index.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include "../common/ref_resource_view.h" // for RefResourceView
2121
#include "../common/threading_utils.h"
2222
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
23-
#include "adapter.h"
23+
#include "entry.h" // for IsValidFunctor
2424
#include "xgboost/base.h"
2525
#include "xgboost/data.h"
2626

src/data/proxy_dmatrix.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ struct ExternalDataInfo {
158158

159159
CHECK_GE(this->n_features, 1) << "Data must has at least 1 column.";
160160
CHECK_EQ(this->base_rowids.size(), this->n_batches + 1);
161+
CHECK_LE(this->row_stride, this->n_features);
161162
}
162163

163164
void SetInfo(Context const* ctx, bool sync, MetaInfo* p_info) {

0 commit comments

Comments
 (0)