Skip to content

Commit 73261fe

Browse files
authored
[mt] Add a new gradient type for the GPU hist. (#11798)
- Use reduced gradient for tree structure exploration. - Expose the new gradient type through the objective interface.
1 parent ca7230f commit 73261fe

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

61 files changed

+1006
-304
lines changed

include/xgboost/gbm.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
#include <dmlc/registry.h>
1212
#include <xgboost/base.h>
1313
#include <xgboost/data.h>
14+
#include <xgboost/gradient.h> // for GradientContainer
1415
#include <xgboost/host_device_vector.h>
1516
#include <xgboost/model.h>
1617

17-
#include <vector>
18-
#include <string>
1918
#include <functional>
2019
#include <memory>
20+
#include <string>
21+
#include <vector>
2122

2223
namespace xgboost {
2324

@@ -78,8 +79,8 @@ class GradientBooster : public Model, public Configurable {
7879
* the booster may change content of gpair
7980
* @param obj The objective function used for boosting.
8081
*/
81-
virtual void DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
82-
PredictionCacheEntry*, ObjFunction const* obj) = 0;
82+
virtual void DoBoost(DMatrix* p_fmat, GradientContainer* in_gpair,
83+
PredictionCacheEntry* prediction, ObjFunction const* obj) = 0;
8384

8485
/**
8586
* \brief Generate predictions for given feature matrix

include/xgboost/gradient.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/**
2+
* Copyright 2025, XGBoost Contributors
3+
*/
4+
#pragma once
5+
6+
#include <xgboost/base.h> // for GradientPair
7+
#include <xgboost/linalg.h> // for Matrix
8+
#include <xgboost/logging.h>
9+
10+
#include <cstddef> // for size_t
11+
12+
namespace xgboost {
13+
/**
14+
* @brief Container for gradient produced by objective.
15+
*/
16+
struct GradientContainer {
17+
/** @brief Gradient used for multi-target tree split and linear model. */
18+
linalg::Matrix<GradientPair> gpair;
19+
/** @brief Gradient used for tree leaf value, optional. */
20+
linalg::Matrix<GradientPair> value_gpair;
21+
22+
[[nodiscard]] bool HasValueGrad() const noexcept { return !value_gpair.Empty(); }
23+
24+
[[nodiscard]] std::size_t NumSplitTargets() const noexcept { return gpair.Shape(1); }
25+
[[nodiscard]] std::size_t NumTargets() const noexcept {
26+
return HasValueGrad() ? value_gpair.Shape(1) : this->gpair.Shape(1);
27+
}
28+
29+
linalg::MatrixView<GradientPair const> ValueGrad(Context const* ctx) const {
30+
if (HasValueGrad()) {
31+
return this->value_gpair.View(ctx->Device());
32+
}
33+
return this->gpair.View(ctx->Device());
34+
}
35+
36+
[[nodiscard]] linalg::Matrix<GradientPair> const* Grad() const { return &gpair; }
37+
[[nodiscard]] linalg::Matrix<GradientPair>* Grad() { return &gpair; }
38+
39+
[[nodiscard]] linalg::Matrix<GradientPair> const* FullGradOnly() const {
40+
if (this->HasValueGrad()) {
41+
LOG(FATAL) << "Reduced gradient is not yet supported.";
42+
}
43+
return this->Grad();
44+
}
45+
[[nodiscard]] linalg::Matrix<GradientPair>* FullGradOnly() {
46+
if (this->HasValueGrad()) {
47+
LOG(FATAL) << "Reduced gradient is not yet supported.";
48+
}
49+
return this->Grad();
50+
}
51+
};
52+
} // namespace xgboost

include/xgboost/learner.h

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,23 @@
88
#ifndef XGBOOST_LEARNER_H_
99
#define XGBOOST_LEARNER_H_
1010

11-
#include <dmlc/io.h> // for Serializable
12-
#include <xgboost/base.h> // for bst_feature_t, bst_target_t, bst_float, Args, GradientPair, ..
13-
#include <xgboost/context.h> // for Context
14-
#include <xgboost/linalg.h> // for Vector, VectorView
15-
#include <xgboost/metric.h> // for Metric
16-
#include <xgboost/model.h> // for Configurable, Model
17-
#include <xgboost/span.h> // for Span
18-
#include <xgboost/task.h> // for ObjInfo
11+
#include <dmlc/io.h> // for Serializable
12+
#include <xgboost/base.h> // for bst_feature_t, bst_target_t, bst_float, Args, GradientPair, ..
13+
#include <xgboost/context.h> // for Context
14+
#include <xgboost/gradient.h> // for GradientContainer
15+
#include <xgboost/linalg.h> // for Vector, VectorView
16+
#include <xgboost/metric.h> // for Metric
17+
#include <xgboost/model.h> // for Configurable, Model
18+
#include <xgboost/span.h> // for Span
19+
#include <xgboost/task.h> // for ObjInfo
1920

20-
#include <algorithm> // for max
21-
#include <cstdint> // for int32_t, uint32_t, uint8_t
22-
#include <map> // for map
23-
#include <memory> // for shared_ptr, unique_ptr
24-
#include <string> // for string
25-
#include <utility> // for move
26-
#include <vector> // for vector
21+
#include <algorithm> // for max
22+
#include <cstdint> // for int32_t, uint32_t, uint8_t
23+
#include <map> // for map
24+
#include <memory> // for shared_ptr, unique_ptr
25+
#include <string> // for string
26+
#include <utility> // for move
27+
#include <vector> // for vector
2728

2829
namespace xgboost {
2930
class FeatureMap;
@@ -47,25 +48,24 @@ enum class PredictionType : std::uint8_t { // NOLINT
4748
kLeaf = 6
4849
};
4950

50-
/*!
51-
* \brief Learner class that does training and prediction.
51+
/**
52+
* @brief Learner class that does training and prediction.
5253
* This is the user facing module of xgboost training.
5354
* The Load/Save function corresponds to the model used in python/R.
54-
* \code
55+
* @code
5556
*
56-
* std::unique_ptr<Learner> learner(new Learner::Create(cache_mats));
57-
* learner.Configure(configs);
57+
* std::unique_ptr<Learner> learner{Learner::Create(cache_mats)};
58+
* learner->Configure(configs);
5859
*
5960
* for (int iter = 0; iter < max_iter; ++iter) {
6061
* learner->UpdateOneIter(iter, train_mat);
6162
* LOG(INFO) << learner->EvalOneIter(iter, data_sets, data_names);
6263
* }
6364
*
64-
* \endcode
65+
* @endcode
6566
*/
6667
class Learner : public Model, public Configurable, public dmlc::Serializable {
6768
public:
68-
/*! \brief virtual destructor */
6969
~Learner() override;
7070
/*!
7171
* \brief Configure Learner based on set parameters.
@@ -88,7 +88,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
8888
* @param in_gpair The input gradient statistics.
8989
*/
9090
virtual void BoostOneIter(std::int32_t iter, std::shared_ptr<DMatrix> train,
91-
linalg::Matrix<GradientPair>* in_gpair) = 0;
91+
GradientContainer* in_gpair) = 0;
9292
/*!
9393
* \brief evaluate the model for specific iteration using the configured metrics.
9494
* \param iter iteration number

include/xgboost/linalg.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,7 @@ template <typename T>
957957
using Vector = Tensor<T, 1>;
958958

959959
/**
960-
* \brief Create an array without initialization.
960+
* @brief Create an array without initialization.
961961
*/
962962
template <typename T, typename... Index>
963963
auto Empty(Context const *ctx, Index &&...index) {
@@ -967,6 +967,17 @@ auto Empty(Context const *ctx, Index &&...index) {
967967
return t;
968968
}
969969

970+
/**
971+
* @brief Create an array with the same shape and dtype as the input.
972+
*/
973+
template <typename T, std::int32_t kDim>
974+
auto EmptyLike(Context const *ctx, Tensor<T, kDim> const &in) {
975+
Tensor<T, kDim> t;
976+
t.SetDevice(ctx->Device());
977+
t.Reshape(in.Shape());
978+
return t;
979+
}
980+
970981
/**
971982
* \brief Create an array with value v.
972983
*/

include/xgboost/multi_target_tree_model.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,18 @@ class MultiTargetTree : public Model {
6060
MultiTargetTree& operator=(MultiTargetTree&& that) = delete;
6161

6262
/**
63-
* @brief Set the weight for a leaf.
63+
* @brief Set the weight for the root.
6464
*/
65-
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight);
65+
void SetRoot(linalg::VectorView<float const> weight);
6666
/**
6767
* @brief Expand a leaf into split node.
6868
*/
6969
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
7070
linalg::VectorView<float const> base_weight,
7171
linalg::VectorView<float const> left_weight,
7272
linalg::VectorView<float const> right_weight);
73+
/** @see RegTree::SetLeaves */
74+
void SetLeaves(std::vector<bst_node_t> leaves, common::Span<float const> weights);
7375

7476
[[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
7577
return left_.ConstHostVector()[nidx] == InvalidNodeId();

include/xgboost/objective.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class ObjFunction : public Configurable {
129129
virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
130130
MetaInfo const& /*info*/, float /*learning_rate*/,
131131
HostDeviceVector<float> const& /*prediction*/,
132-
std::int32_t /*group_idx*/, RegTree* /*p_tree*/) const {}
132+
bst_target_t /*group_idx*/, RegTree* /*p_tree*/) const {}
133133
/**
134134
* @brief Create an objective function according to the name.
135135
*

include/xgboost/tree_model.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -321,12 +321,23 @@ class RegTree : public Model {
321321
float right_sum,
322322
bst_node_t leaf_right_child = kInvalidNodeId);
323323
/**
324-
* \brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
324+
* @brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
325325
*/
326326
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
327327
linalg::VectorView<float const> base_weight,
328328
linalg::VectorView<float const> left_weight,
329329
linalg::VectorView<float const> right_weight);
330+
/**
331+
* @brief Set all leaf weights for a multi-target tree.
332+
*
333+
* The leaf weight can be different from the internal weight stored by @ref ExpandNode.
334+
* This function is used to set the leaf at the end of tree construction.
335+
*
336+
* @param leaves The node indices for all leaves. This must contain all the leaves in this tree.
337+
* @param weights Row-major matrix for leaf weights, each row contains a leaf specified by the
338+
* leaves parameter.
339+
*/
340+
void SetLeaves(std::vector<bst_node_t> leaves, common::Span<float const> weights);
330341

331342
/**
332343
* \brief Expands a leaf node with categories
@@ -396,11 +407,11 @@ class RegTree : public Model {
396407
*/
397408
[[nodiscard]] bst_node_t GetDepth(bst_node_t nidx) const;
398409
/**
399-
* @brief Set the leaf weight for a multi-target tree.
410+
* @brief Set the root weight for a multi-target tree.
400411
*/
401-
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
412+
void SetRoot(linalg::VectorView<float const> weight) {
402413
CHECK(IsMultiTarget());
403-
return this->p_mt_tree_->SetLeaf(nidx, weight);
414+
return this->p_mt_tree_->SetRoot(weight);
404415
}
405416
/**
406417
* @brief Get the maximum depth.

include/xgboost/tree_updater.h

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/**
2-
* Copyright 2014-2023 by XGBoost Contributors
3-
* \file tree_updater.h
4-
* \brief General primitive for tree learning,
2+
* Copyright 2014-2025, XGBoost Contributors
3+
*
4+
* @brief General primitive for tree learning,
55
* Updating a collection of trees given the information.
66
* \author Tianqi Chen
77
*/
@@ -10,16 +10,17 @@
1010

1111
#include <dmlc/registry.h>
1212
#include <xgboost/base.h> // for Args, GradientPair
13-
#include <xgboost/data.h> // DMatrix
13+
#include <xgboost/data.h> // for DMatrix
14+
#include <xgboost/gradient.h> // for GradientContainer
1415
#include <xgboost/host_device_vector.h> // for HostDeviceVector
1516
#include <xgboost/linalg.h> // for VectorView
1617
#include <xgboost/model.h> // for Configurable
1718
#include <xgboost/span.h> // for Span
1819
#include <xgboost/tree_model.h> // for RegTree
1920

20-
#include <functional> // for function
21-
#include <string> // for string
22-
#include <vector> // for vector
21+
#include <functional> // for function
22+
#include <string> // for string
23+
#include <vector> // for vector
2324

2425
namespace xgboost {
2526
namespace tree {
@@ -59,21 +60,21 @@ class TreeUpdater : public Configurable {
5960
*/
6061
[[nodiscard]] virtual bool HasNodePosition() const { return false; }
6162
/**
62-
* \brief perform update to the tree models
63+
* @brief perform update to the tree models
6364
*
64-
* \param param Hyper-parameter for constructing trees.
65-
* \param gpair the gradient pair statistics of the data
66-
* \param data The data matrix passed to the updater.
67-
* \param out_position The leaf index for each row. The index is negated if that row is
65+
* @param param Hyper-parameter for constructing trees.
66+
* @param gpair The gradient pair statistics of the data
67+
* @param p_fmat The data matrix passed to the updater.
68+
* @param out_position The leaf index for each row. The index is negated if that row is
6869
removed during sampling. So the 3rd node is ~3.
69-
* \param out_trees references the trees to be updated, updater will change the content of trees
70+
* @param out_trees references the trees to be updated, updater will change the content of trees
7071
* note: all the trees in the vector are updated, with the same statistics,
7172
* but maybe different random seeds, usually one tree is passed in at a time,
7273
* there can be multiple trees when we train random forest style model
7374
*/
74-
virtual void Update(tree::TrainParam const* param, linalg::Matrix<GradientPair>* gpair,
75-
DMatrix* data, common::Span<HostDeviceVector<bst_node_t>> out_position,
76-
const std::vector<RegTree*>& out_trees) = 0;
75+
virtual void Update(tree::TrainParam const* param, GradientContainer* gpair, DMatrix* p_fmat,
76+
common::Span<HostDeviceVector<bst_node_t>> out_position,
77+
std::vector<RegTree*> const& out_trees) = 0;
7778

7879
/*!
7980
* \brief determines whether updater has enough knowledge about a given dataset

plugin/sycl/tree/updater_quantile_hist.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma GCC diagnostic push
99
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
1010
#pragma GCC diagnostic ignored "-W#pragma-messages"
11+
#include "xgboost/gradient.h" // for GradientContainer
1112
#include "xgboost/tree_updater.h"
1213
#pragma GCC diagnostic pop
1314

@@ -72,11 +73,11 @@ void QuantileHistMaker::CallUpdate(
7273
}
7374
}
7475

75-
void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
76-
linalg::Matrix<GradientPair>* gpair,
76+
void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, GradientContainer *in_gpair,
7777
DMatrix *dmat,
7878
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
7979
const std::vector<RegTree *> &trees) {
80+
auto gpair = in_gpair->FullGradOnly();
8081
gpair->Data()->SetDevice(ctx_->Device());
8182
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
8283
updater_monitor_.Start("GmatInitialization");

plugin/sycl/tree/updater_quantile_hist.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
/*!
2-
* Copyright 2017-2024 by Contributors
1+
/**
2+
* Copyright 2017-2025, XGBoost Contributors
33
* \file updater_quantile_hist.h
44
*/
55
#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
@@ -8,21 +8,21 @@
88
#include <dmlc/timer.h>
99
#include <xgboost/tree_updater.h>
1010

11-
#include <vector>
1211
#include <memory>
12+
#include <vector>
1313

14-
#include "../data/gradient_index.h"
14+
#include "../../src/common/random.h"
15+
#include "../../src/tree/constraints.h"
1516
#include "../common/hist_util.h"
16-
#include "../common/row_set.h"
1717
#include "../common/partition_builder.h"
18-
#include "split_evaluator.h"
18+
#include "../common/row_set.h"
19+
#include "../data/gradient_index.h"
1920
#include "../device_manager.h"
2021
#include "hist_updater.h"
22+
#include "split_evaluator.h"
2123
#include "xgboost/data.h"
22-
24+
#include "xgboost/gradient.h" // for GradientContainer
2325
#include "xgboost/json.h"
24-
#include "../../src/tree/constraints.h"
25-
#include "../../src/common/random.h"
2626

2727
namespace xgboost {
2828
namespace sycl {
@@ -48,9 +48,7 @@ class QuantileHistMaker: public TreeUpdater {
4848
}
4949
void Configure(const Args& args) override;
5050

51-
void Update(xgboost::tree::TrainParam const *param,
52-
linalg::Matrix<GradientPair>* gpair,
53-
DMatrix* dmat,
51+
void Update(xgboost::tree::TrainParam const* param, GradientContainer* in_gpair, DMatrix* dmat,
5452
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
5553
const std::vector<RegTree*>& trees) override;
5654

0 commit comments

Comments
 (0)