Skip to content

Commit 495553b

Browse files
authored
[mt] Initial work on vector leaf for the GPU predictor. (#11729)
1 parent 5877016 commit 495553b

File tree

5 files changed

+158
-78
lines changed

5 files changed

+158
-78
lines changed

src/predictor/cpu_predictor.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,6 @@ struct LaunchConfig : public Args... {
460460
}
461461
} else {
462462
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
463-
// bool any_missing = !page.IsDense();
464463
fn(SparsePageView{page.GetView(), page.base_rowid, acc});
465464
}
466465
}

src/predictor/gpu_predictor.cu

Lines changed: 107 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "utils.h" // for CheckProxyDMatrix
2828
#include "xgboost/data.h"
2929
#include "xgboost/host_device_vector.h"
30+
#include "xgboost/multi_target_tree_model.h" // for MultiTargetTree, MultiTargetTreeView
3031
#include "xgboost/predictor.h"
3132
#include "xgboost/tree_model.h"
3233
#include "xgboost/tree_updater.h"
@@ -243,6 +244,55 @@ struct DeviceAdapterLoader {
243244
}
244245
};
245246

247+
namespace multi {
248+
/**
 * @brief Choose the child to visit next from node `nidx` of a vector-leaf tree.
 *
 * If `has_missing` is enabled and `is_missing` is set, the node's default child is returned.
 * Otherwise the feature value is compared with the node's split condition: a strictly smaller
 * value goes to the left child, anything else goes to the right child.
 *
 * @tparam has_missing     Whether the `is_missing` flag is honoured.
 * @tparam has_categorical Not consulted in this function; only the numerical comparison is
 *                         performed here.
 *
 * @param tree       View of the tree being traversed.
 * @param nidx       Index of the current node.
 * @param fvalue     Feature value used for the split comparison.
 * @param is_missing Whether `fvalue` should be treated as missing.
 *
 * @return Index of the next node to visit.
 */
template <bool has_missing, bool has_categorical>
XGBOOST_DEVICE bst_node_t GetNextNode(MultiTargetTreeView const& tree, bst_node_t const nidx,
                                      float fvalue, bool is_missing) {
  // Missing values follow the default direction, but only when missing handling is enabled.
  if (has_missing && is_missing) {
    return tree.DefaultChild(nidx);
  }
  bool const go_left = fvalue < tree.SplitCond(nidx);
  return go_left ? tree.LeftChild(nidx) : tree.RightChild(nidx);
}
257+
258+
/**
 * @brief Walk a vector-leaf tree for one row and return the index of the leaf it lands in.
 *
 * Traversal starts at node 0. At each non-leaf node the row's value for the node's split
 * feature is fetched through `loader`, and GetNextNode decides which child to visit.
 *
 * @param ridx   Row index handed to the loader.
 * @param tree   View of the tree being traversed.
 * @param loader Feature accessor; must provide `GetElement(ridx, fidx)`.
 *
 * @return Index of the leaf node reached.
 */
template <bool has_missing, bool has_categorical, typename Loader>
__device__ bst_node_t GetLeafIndex(bst_idx_t ridx, MultiTargetTreeView const& tree,
                                   Loader* loader) {
  bst_node_t cur = 0;
  for (; !tree.IsLeaf(cur);) {
    float const value = loader->GetElement(ridx, tree.SplitIndex(cur));
    // A NaN feature value is what gets reported to GetNextNode as "missing".
    bst_node_t const child = GetNextNode<has_missing, has_categorical>(
        tree, cur, value, common::CheckNAN(value));
    // The child index must be greater than the current one, so the walk always advances.
    assert(cur < child);
    cur = child;
  }
  return cur;
}
271+
272+
/**
 * @brief Return the leaf value of the leaf that row `ridx` reaches in `tree`.
 *
 * The leaf is located with GetLeafIndex instantiated with `has_categorical = false`.
 * The return type is whatever `tree.LeafValue()` yields for that node.
 *
 * @param ridx   Row index handed to the loader.
 * @param tree   View of the tree being evaluated.
 * @param loader Feature accessor forwarded to GetLeafIndex.
 */
template <bool has_missing, typename Loader>
__device__ auto GetLeafWeight(bst_idx_t ridx, MultiTargetTreeView const& tree, Loader* loader) {
  return tree.LeafValue(GetLeafIndex<has_missing, false>(ridx, tree, loader));
}
277+
278+
/**
 * @brief Prediction kernel for vector-leaf (multi-target) trees.
 *
 * For every row visited by the stride loop, each tree in `trees` is traversed and the
 * elements of the reached leaf are accumulated into `d_out_predt(row, target)`.
 *
 * @tparam Loader      Feature loader type, constructed from `data`.
 * @tparam Data        Input batch view; must provide `NumRows()` and `NumCols()`.
 * @tparam has_missing Forwarded to the tree traversal.
 * @tparam EncAccessor Accessor type forwarded to the loader.
 *
 * @param data        Input batch; moved into the loader.
 * @param trees       Views of the trees to evaluate.
 * @param use_shared  Forwarded to the loader.
 * @param missing     Forwarded to the loader.
 * @param d_out_predt Output matrix indexed as (row, target); results are added to the
 *                    existing content.
 * @param acc         Accessor; moved into the loader.
 */
template <typename Loader, typename Data, bool has_missing, typename EncAccessor>
__global__ void PredictKernel(Data data, common::Span<MultiTargetTreeView> trees, bool use_shared,
                              float missing, linalg::MatrixView<float> d_out_predt,
                              EncAccessor acc) {
  // Read the dimensions before `data` is handed to the loader, so nothing is queried from a
  // moved-from object.
  auto const n_rows = data.NumRows();
  auto const n_features = static_cast<bst_feature_t>(data.NumCols());
  // Build the loader exactly once per thread, outside the row loop. Constructing it inside
  // the loop would repeat `std::move(data)` / `std::move(acc)` on every iteration, and
  // threads that have no row to process would never construct it.
  // NOTE(review): this matters if the loader constructor performs block-wide work when
  // `use_shared` is set -- confirm against the loader definitions.
  Loader loader{std::move(data), use_shared, n_features, n_rows, missing, std::move(acc)};

  for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n_rows)) {
    for (auto const& tree : trees) {
      auto leaf = GetLeafWeight<has_missing>(idx, tree, &loader);
      // Accumulate every element of the leaf into the corresponding output column.
      for (std::size_t i = 0, n = leaf.Shape(0); i < n; ++i) {
        d_out_predt(idx, i) += leaf(i);
      }
    }
  }
}
293+
} // namespace multi
294+
295+
namespace scalar {
246296
template <bool has_missing, bool has_categorical, typename Loader>
247297
__device__ bst_node_t GetLeafIndex(bst_idx_t ridx, TreeView const& tree, Loader* loader) {
248298
bst_node_t nidx = 0;
@@ -257,8 +307,7 @@ __device__ bst_node_t GetLeafIndex(bst_idx_t ridx, TreeView const& tree, Loader*
257307
}
258308

259309
template <bool has_missing, typename Loader>
260-
__device__ float GetLeafWeight(bst_idx_t ridx, TreeView const &tree,
261-
Loader *loader) {
310+
__device__ float GetLeafWeight(bst_idx_t ridx, TreeView const& tree, Loader* loader) {
262311
bst_node_t nidx = -1;
263312
if (tree.HasCategoricalSplit()) {
264313
nidx = GetLeafIndex<has_missing, true>(ridx, tree, loader);
@@ -267,6 +316,7 @@ __device__ float GetLeafWeight(bst_idx_t ridx, TreeView const &tree,
267316
}
268317
return tree.d_tree[nidx].LeafValue();
269318
}
319+
} // namespace scalar
270320

271321
template <typename Loader, typename Data, bool has_missing, typename EncAccessor>
272322
__global__ void
@@ -295,9 +345,9 @@ PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
295345

296346
bst_node_t leaf = -1;
297347
if (d_tree.HasCategoricalSplit()) {
298-
leaf = GetLeafIndex<has_missing, true>(ridx, d_tree, &loader);
348+
leaf = scalar::GetLeafIndex<has_missing, true>(ridx, d_tree, &loader);
299349
} else {
300-
leaf = GetLeafIndex<has_missing, false>(ridx, d_tree, &loader);
350+
leaf = scalar::GetLeafIndex<has_missing, false>(ridx, d_tree, &loader);
301351
}
302352
d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf;
303353
}
@@ -313,7 +363,7 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
313363
common::Span<uint32_t const> d_cat_tree_segments,
314364
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
315365
common::Span<uint32_t const> d_categories, bst_tree_t tree_begin,
316-
bst_tree_t tree_end, bst_feature_t num_features, size_t num_rows,
366+
bst_tree_t tree_end, bst_feature_t num_features, bst_idx_t num_rows,
317367
bool use_shared, int num_group, float missing, EncAccessor acc) {
318368
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
319369
Loader loader{std::move(data), use_shared, num_features, num_rows, missing, std::move(acc)};
@@ -326,20 +376,19 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
326376
tree_begin, tree_idx, d_nodes,
327377
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
328378
d_cat_node_segments, d_categories};
329-
float leaf = GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
379+
float leaf = scalar::GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
330380
sum += leaf;
331381
}
332382
d_out_predictions[global_idx] += sum;
333383
} else {
334384
for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
335385
int tree_group = d_tree_group[tree_idx];
336-
TreeView d_tree{
337-
tree_begin, tree_idx, d_nodes,
338-
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
339-
d_cat_node_segments, d_categories};
386+
TreeView d_tree{tree_begin, tree_idx, d_nodes,
387+
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
388+
d_cat_node_segments, d_categories};
340389
bst_uint out_prediction_idx = global_idx * num_group + tree_group;
341390
d_out_predictions[out_prediction_idx] +=
342-
GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
391+
scalar::GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
343392
}
344393
}
345394
}
@@ -400,12 +449,12 @@ class DeviceModel {
400449
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
401450
auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
402451
auto& src_stats = model.trees.at(tree_idx)->GetStats();
403-
dh::safe_cuda(cudaMemcpyAsync(
404-
d_nodes + h_tree_segments[tree_idx - tree_begin], src_nodes.data(),
405-
sizeof(RegTree::Node) * src_nodes.size(), cudaMemcpyDefault));
406-
dh::safe_cuda(cudaMemcpyAsync(
407-
d_stats + h_tree_segments[tree_idx - tree_begin], src_stats.data(),
408-
sizeof(RTreeNodeStat) * src_stats.size(), cudaMemcpyDefault));
452+
dh::safe_cuda(cudaMemcpyAsync(d_nodes + h_tree_segments[tree_idx - tree_begin],
453+
src_nodes.data(), sizeof(RegTree::Node) * src_nodes.size(),
454+
cudaMemcpyDefault));
455+
dh::safe_cuda(cudaMemcpyAsync(d_stats + h_tree_segments[tree_idx - tree_begin],
456+
src_stats.data(), sizeof(RTreeNodeStat) * src_stats.size(),
457+
cudaMemcpyDefault));
409458
}
410459

411460
tree_group = HostDeviceVector<int>(model.tree_info.size(), 0, device);
@@ -424,14 +473,13 @@ class DeviceModel {
424473

425474
categories = HostDeviceVector<uint32_t>({}, device);
426475
categories_tree_segments = HostDeviceVector<uint32_t>(1, 0, device);
427-
std::vector<uint32_t> &h_categories = categories.HostVector();
428-
std::vector<uint32_t> &h_split_cat_segments = categories_tree_segments.HostVector();
476+
std::vector<uint32_t>& h_categories = categories.HostVector();
477+
std::vector<uint32_t>& h_split_cat_segments = categories_tree_segments.HostVector();
429478
for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
430479
auto const& src_cats = model.trees.at(tree_idx)->GetSplitCategories();
431480
size_t orig_size = h_categories.size();
432481
h_categories.resize(orig_size + src_cats.size());
433-
std::copy(src_cats.cbegin(), src_cats.cend(),
434-
h_categories.begin() + orig_size);
482+
std::copy(src_cats.cbegin(), src_cats.cend(), h_categories.begin() + orig_size);
435483
h_split_cat_segments.push_back(h_categories.size());
436484
}
437485

@@ -974,7 +1022,7 @@ class LaunchConfig {
9741022
void LaunchPredict(Context const* ctx, Data data, float missing, bst_idx_t n_samples,
9751023
bst_feature_t n_features, DeviceModel const& model, bool is_dense,
9761024
enc::DeviceColumnsView const& new_enc, bst_idx_t batch_offset,
977-
HostDeviceVector<bst_float>* predictions) const {
1025+
HostDeviceVector<float>* predictions) const {
9781026
LaunchPredictKernel(ctx, is_dense, new_enc, model, [&](auto is_dense, auto&& acc) {
9791027
constexpr bool kHasMissing = !std::is_same_v<decltype(is_dense), std::true_type>;
9801028
using EncAccessor = std::remove_reference_t<decltype(acc)>;
@@ -993,10 +1041,30 @@ class LaunchConfig {
9931041
});
9941042
}
9951043

1044+
/**
 * @brief Launch the vector-leaf prediction kernel for trees in [tree_begin, tree_end).
 *
 * A view of each selected tree is collected on the host, copied into a device vector, and
 * handed to `multi::PredictKernel` together with a (rows x targets) view over `predictions`.
 * The kernel is instantiated with `has_missing = true` and a `NoOpAccessor`.
 *
 * @param ctx          Context used to create the tree views and the output tensor view.
 * @param data         Input batch view; moved into the kernel launch.
 * @param model        Model whose trees are evaluated.
 * @param missing      Missing value forwarded to the kernel.
 * @param tree_begin   First tree to evaluate (inclusive).
 * @param tree_end     One past the last tree to evaluate; must be greater than `tree_begin`.
 * @param batch_offset Must be 0.
 * @param predictions  Output buffer; must hold at least rows x targets elements.
 */
template <template <typename> typename Loader, typename Data>
void LaunchMultiPredict(Context const* ctx, Data data, gbm::GBTreeModel const& model,
                        float missing, bst_tree_t tree_begin, bst_tree_t tree_end,
                        bst_idx_t batch_offset, HostDeviceVector<float>* predictions) const {
  CHECK_EQ(batch_offset, 0);  // External memory is not supported yet.
  CHECK_GT(tree_end, tree_begin);

  // Gather a view of every selected tree on the host, then copy them to the device.
  std::vector<MultiTargetTreeView> h_trees;
  for (auto tidx = tree_begin; tidx < tree_end; ++tidx) {
    h_trees.emplace_back(model.trees[tidx]->GetMultiTargetTree()->View(ctx));
  }
  dh::device_vector<MultiTargetTreeView> trees = h_trees;

  // The output shape is taken from the batch row count and the first tree's target count.
  auto const n_rows = data.NumRows();
  auto const n_targets = h_trees.front().NumTargets();
  CHECK_GE(predictions->Size(), n_rows * n_targets);

  auto kernel = multi::PredictKernel<Loader<NoOpAccessor>, Data, true, NoOpAccessor>;
  auto predt = linalg::MakeTensorView(ctx, predictions, n_rows, n_targets);
  this->Grid(n_rows).LaunchImpl(std::move(kernel), std::move(data), dh::ToSpan(trees),
                                this->UseShared(), missing, predt, NoOpAccessor{});
}
1063+
9961064
template <template <typename> typename Loader, typename Data>
9971065
void LaunchLeaf(Context const* ctx, Data data, bst_idx_t n_samples, bst_feature_t n_features,
9981066
DeviceModel const& model, bool is_dense, enc::DeviceColumnsView const& new_enc,
999-
bst_idx_t batch_offset, HostDeviceVector<bst_float>* predictions) const {
1067+
bst_idx_t batch_offset, HostDeviceVector<float>* predictions) const {
10001068
LaunchPredictKernel(ctx, is_dense, new_enc, model, [&](auto is_dense, auto&& acc) {
10011069
constexpr bool kHasMissing = !std::is_same_v<decltype(is_dense), std::true_type>;
10021070
using EncAccessor = std::remove_reference_t<decltype(acc)>;
@@ -1037,7 +1105,9 @@ class GPUPredictor : public xgboost::Predictor {
10371105
out_preds->SetDevice(ctx_->Device());
10381106
auto const& info = p_fmat->Info();
10391107
DeviceModel d_model;
1040-
d_model.Init(model, tree_begin, tree_end, ctx_->Device());
1108+
if (!model.trees[tree_begin]->IsMultiTarget()) {
1109+
d_model.Init(model, tree_begin, tree_end, ctx_->Device());
1110+
}
10411111

10421112
if (info.IsColumnSplit()) {
10431113
column_split_helper_.PredictBatch(p_fmat, out_preds, model, d_model);
@@ -1056,9 +1126,15 @@ class GPUPredictor : public xgboost::Predictor {
10561126
auto n_features = model.learner_model_param->num_feature;
10571127
LaunchConfig cfg{ctx_, n_features};
10581128
SparsePageView data(page.data.DeviceSpan(), page.offset.DeviceSpan(), n_features);
1059-
cfg.LaunchPredict<SparsePageLoader>(
1060-
this->ctx_, std::move(data), std::numeric_limits<float>::quiet_NaN(), page.Size(),
1061-
n_features, d_model, p_fmat->IsDense(), new_enc, batch_offset, out_preds);
1129+
if (model.trees[tree_begin]->IsMultiTarget()) {
1130+
cfg.LaunchMultiPredict<SparsePageLoader>(this->ctx_, std::move(data), model,
1131+
std::numeric_limits<float>::quiet_NaN(),
1132+
tree_begin, tree_end, batch_offset, out_preds);
1133+
} else {
1134+
cfg.LaunchPredict<SparsePageLoader>(
1135+
this->ctx_, std::move(data), std::numeric_limits<float>::quiet_NaN(), page.Size(),
1136+
n_features, d_model, p_fmat->IsDense(), new_enc, batch_offset, out_preds);
1137+
}
10621138
batch_offset += page.Size() * model.learner_model_param->OutputLength();
10631139
}
10641140
} else {
@@ -1158,7 +1234,7 @@ class GPUPredictor : public xgboost::Predictor {
11581234

11591235
void PredictContribution(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
11601236
const gbm::GBTreeModel& model, bst_tree_t tree_end,
1161-
std::vector<bst_float> const* tree_weights, bool approximate, int,
1237+
std::vector<float> const* tree_weights, bool approximate, int,
11621238
unsigned) const override {
11631239
StringView not_implemented{
11641240
"contribution is not implemented in the GPU predictor, use CPU instead."};
@@ -1177,8 +1253,7 @@ class GPUPredictor : public xgboost::Predictor {
11771253
const int ngroup = model.learner_model_param->num_output_group;
11781254
CHECK_NE(ngroup, 0);
11791255
// allocate space for (number of features + bias) times the number of rows
1180-
size_t contributions_columns =
1181-
model.learner_model_param->num_feature + 1; // +1 for bias
1256+
size_t contributions_columns = model.learner_model_param->num_feature + 1; // +1 for bias
11821257
auto dim_size = contributions_columns * model.learner_model_param->num_output_group;
11831258
out_contribs->Resize(p_fmat->Info().num_row_ * dim_size);
11841259
out_contribs->Fill(0.0f);
@@ -1245,8 +1320,8 @@ class GPUPredictor : public xgboost::Predictor {
12451320
gbm::GBTreeModel const& model, bst_tree_t tree_end,
12461321
std::vector<float> const* tree_weights,
12471322
bool approximate) const override {
1248-
std::string not_implemented{"contribution is not implemented in GPU "
1249-
"predictor, use `cpu_predictor` instead."};
1323+
std::string not_implemented{
1324+
"contribution is not implemented in GPU predictor, use cpu instead."};
12501325
if (approximate) {
12511326
LOG(FATAL) << "Approximated " << not_implemented;
12521327
}
@@ -1333,7 +1408,6 @@ class GPUPredictor : public xgboost::Predictor {
13331408
gbm::GBTreeModel const& model, bst_tree_t tree_end) const override {
13341409
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
13351410

1336-
13371411
const MetaInfo& info = p_fmat->Info();
13381412
bst_idx_t n_samples = info.num_row_;
13391413
tree_end = GetTreeLimit(model.trees, tree_end);

tests/cpp/predictor/test_cpu_predictor.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,6 @@ TEST(CpuPredictor, SparseColumnSplit) {
279279

280280
TEST(CpuPredictor, Multi) {
281281
Context ctx;
282-
ctx.nthread = 1;
283282
TestVectorLeafPrediction(&ctx);
284283
}
285284

tests/cpp/predictor/test_gpu_predictor.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,11 @@ TEST(GPUPredictor, PredictLeafBasic) {
351351
}
352352
}
353353

354+
// Runs TestVectorLeafPrediction, the same check the CpuPredictor Multi test calls,
// with a context obtained from MakeCUDACtx(0).
// NOTE(review): the argument 0 is presumably the device ordinal -- confirm.
TEST(GPUPredictor, Multi) {
355+
auto ctx = MakeCUDACtx(0);
356+
TestVectorLeafPrediction(&ctx);
357+
}
358+
354359
TEST(GPUPredictor, Sparse) {
355360
auto ctx = MakeCUDACtx(0);
356361
TestSparsePrediction(&ctx, 0.2);

0 commit comments

Comments
 (0)