Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions include/xgboost/gbm.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
#include <dmlc/registry.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <xgboost/gradient.h> // for GradientContainer
#include <xgboost/host_device_vector.h>
#include <xgboost/model.h>

#include <vector>
#include <string>
#include <functional>
#include <memory>
#include <string>
#include <vector>

namespace xgboost {

Expand Down Expand Up @@ -78,8 +79,8 @@ class GradientBooster : public Model, public Configurable {
* the booster may change content of gpair
* @param obj The objective function used for boosting.
*/
virtual void DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
PredictionCacheEntry*, ObjFunction const* obj) = 0;
virtual void DoBoost(DMatrix* p_fmat, GradientContainer* in_gpair,
PredictionCacheEntry* prediction, ObjFunction const* obj) = 0;

/**
* \brief Generate predictions for given feature matrix
Expand Down
52 changes: 52 additions & 0 deletions include/xgboost/gradient.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/**
* Copyright 2025, XGBoost Contributors
*/
#pragma once

#include <xgboost/base.h> // for GradientPair
#include <xgboost/linalg.h> // for Matrix
#include <xgboost/logging.h>

#include <cstddef> // for size_t

namespace xgboost {
/**
* @brief Container for gradient produced by objective.
*/
struct GradientContainer {
  /** @brief Gradient used for multi-target tree split and linear model. */
  linalg::Matrix<GradientPair> gpair;
  /** @brief Gradient used for tree leaf value, optional. */
  linalg::Matrix<GradientPair> value_gpair;

  /** @brief Whether a dedicated leaf-value gradient is present. */
  [[nodiscard]] bool HasValueGrad() const noexcept { return !value_gpair.Empty(); }

  /** @brief Number of targets used for finding tree splits. */
  [[nodiscard]] std::size_t NumSplitTargets() const noexcept { return gpair.Shape(1); }
  /**
   * @brief Total number of targets.
   *
   * Falls back to the split gradient's width when no value gradient is stored.
   */
  [[nodiscard]] std::size_t NumTargets() const noexcept {
    return HasValueGrad() ? value_gpair.Shape(1) : this->gpair.Shape(1);
  }

  /**
   * @brief View of the gradient used for computing leaf values on the given device.
   *
   * Returns the value gradient when present, otherwise the split gradient.
   */
  [[nodiscard]] linalg::MatrixView<GradientPair const> ValueGrad(Context const* ctx) const {
    if (HasValueGrad()) {
      return this->value_gpair.View(ctx->Device());
    }
    return this->gpair.View(ctx->Device());
  }

  /** @brief Access the split gradient. */
  [[nodiscard]] linalg::Matrix<GradientPair> const* Grad() const { return &gpair; }
  /** @brief Access the split gradient (mutable). */
  [[nodiscard]] linalg::Matrix<GradientPair>* Grad() { return &gpair; }

  /**
   * @brief Access the split gradient, aborting when a reduced (value) gradient exists.
   *
   * For callers that cannot yet handle a separate leaf-value gradient.
   */
  [[nodiscard]] linalg::Matrix<GradientPair> const* FullGradOnly() const {
    if (this->HasValueGrad()) {
      LOG(FATAL) << "Reduced gradient is not yet supported.";
    }
    return this->Grad();
  }
  /** @brief Mutable overload of @ref FullGradOnly. */
  [[nodiscard]] linalg::Matrix<GradientPair>* FullGradOnly() {
    if (this->HasValueGrad()) {
      LOG(FATAL) << "Reduced gradient is not yet supported.";
    }
    return this->Grad();
  }
};
} // namespace xgboost
46 changes: 23 additions & 23 deletions include/xgboost/learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,23 @@
#ifndef XGBOOST_LEARNER_H_
#define XGBOOST_LEARNER_H_

#include <dmlc/io.h> // for Serializable
#include <xgboost/base.h> // for bst_feature_t, bst_target_t, bst_float, Args, GradientPair, ..
#include <xgboost/context.h> // for Context
#include <xgboost/linalg.h> // for Vector, VectorView
#include <xgboost/metric.h> // for Metric
#include <xgboost/model.h> // for Configurable, Model
#include <xgboost/span.h> // for Span
#include <xgboost/task.h> // for ObjInfo
#include <dmlc/io.h> // for Serializable
#include <xgboost/base.h> // for bst_feature_t, bst_target_t, bst_float, Args, GradientPair, ..
#include <xgboost/context.h> // for Context
#include <xgboost/gradient.h> // for GradientContainer
#include <xgboost/linalg.h> // for Vector, VectorView
#include <xgboost/metric.h> // for Metric
#include <xgboost/model.h> // for Configurable, Model
#include <xgboost/span.h> // for Span
#include <xgboost/task.h> // for ObjInfo

#include <algorithm> // for max
#include <cstdint> // for int32_t, uint32_t, uint8_t
#include <map> // for map
#include <memory> // for shared_ptr, unique_ptr
#include <string> // for string
#include <utility> // for move
#include <vector> // for vector
#include <algorithm> // for max
#include <cstdint> // for int32_t, uint32_t, uint8_t
#include <map> // for map
#include <memory> // for shared_ptr, unique_ptr
#include <string> // for string
#include <utility> // for move
#include <vector> // for vector

namespace xgboost {
class FeatureMap;
Expand All @@ -47,25 +48,24 @@ enum class PredictionType : std::uint8_t { // NOLINT
kLeaf = 6
};

/*!
* \brief Learner class that does training and prediction.
/**
* @brief Learner class that does training and prediction.
* This is the user facing module of xgboost training.
* The Load/Save function corresponds to the model used in python/R.
* \code
* @code
*
* std::unique_ptr<Learner> learner(new Learner::Create(cache_mats));
* learner.Configure(configs);
* std::unique_ptr<Learner> learner{Learner::Create(cache_mats)};
* learner->Configure(configs);
*
* for (int iter = 0; iter < max_iter; ++iter) {
* learner->UpdateOneIter(iter, train_mat);
* LOG(INFO) << learner->EvalOneIter(iter, data_sets, data_names);
* }
*
* \endcode
* @endcode
*/
class Learner : public Model, public Configurable, public dmlc::Serializable {
public:
/*! \brief virtual destructor */
~Learner() override;
/*!
* \brief Configure Learner based on set parameters.
Expand All @@ -88,7 +88,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* @param in_gpair The input gradient statistics.
*/
virtual void BoostOneIter(std::int32_t iter, std::shared_ptr<DMatrix> train,
linalg::Matrix<GradientPair>* in_gpair) = 0;
GradientContainer* in_gpair) = 0;
/*!
* \brief evaluate the model for specific iteration using the configured metrics.
* \param iter iteration number
Expand Down
13 changes: 12 additions & 1 deletion include/xgboost/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ template <typename T>
using Vector = Tensor<T, 1>;

/**
* \brief Create an array without initialization.
* @brief Create an array without initialization.
*/
template <typename T, typename... Index>
auto Empty(Context const *ctx, Index &&...index) {
Expand All @@ -967,6 +967,17 @@ auto Empty(Context const *ctx, Index &&...index) {
return t;
}

/**
 * @brief Create an uninitialized array with the same shape and dtype as the input.
 *
 * @param ctx Context providing the target device for the new tensor.
 * @param in  Tensor whose shape and element type are replicated.
 *
 * @return A new tensor on `ctx`'s device with `in`'s shape; contents are uninitialized.
 */
template <typename T, std::int32_t kDim>
auto EmptyLike(Context const *ctx, Tensor<T, kDim> const &in) {
  Tensor<T, kDim> out;
  out.SetDevice(ctx->Device());
  out.Reshape(in.Shape());
  return out;
}

/**
* \brief Create an array with value v.
*/
Expand Down
6 changes: 4 additions & 2 deletions include/xgboost/multi_target_tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,18 @@ class MultiTargetTree : public Model {
MultiTargetTree& operator=(MultiTargetTree&& that) = delete;

/**
* @brief Set the weight for a leaf.
* @brief Set the weight for the root.
*/
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight);
void SetRoot(linalg::VectorView<float const> weight);
/**
* @brief Expand a leaf into split node.
*/
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight);
/** @see RegTree::SetLeaves */
void SetLeaves(std::vector<bst_node_t> leaves, common::Span<float const> weights);

[[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
return left_.ConstHostVector()[nidx] == InvalidNodeId();
Expand Down
2 changes: 1 addition & 1 deletion include/xgboost/objective.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class ObjFunction : public Configurable {
virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
MetaInfo const& /*info*/, float /*learning_rate*/,
HostDeviceVector<float> const& /*prediction*/,
std::int32_t /*group_idx*/, RegTree* /*p_tree*/) const {}
bst_target_t /*group_idx*/, RegTree* /*p_tree*/) const {}
/**
* @brief Create an objective function according to the name.
*
Expand Down
19 changes: 15 additions & 4 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,23 @@ class RegTree : public Model {
float right_sum,
bst_node_t leaf_right_child = kInvalidNodeId);
/**
* \brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
* @brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
*/
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight);
/**
* @brief Set all leaf weights for a multi-target tree.
*
* The leaf weight can be different from the internal weight stored by @ref ExpandNode.
* This function is used to set the leaf at the end of tree construction.
*
* @param leaves The node indices for all leaves. This must contain all the leaves in this tree.
* @param weights Row-major matrix for leaf weights, each row contains a leaf specified by the
* leaves parameter.
*/
void SetLeaves(std::vector<bst_node_t> leaves, common::Span<float const> weights);

/**
* \brief Expands a leaf node with categories
Expand Down Expand Up @@ -396,11 +407,11 @@ class RegTree : public Model {
*/
[[nodiscard]] bst_node_t GetDepth(bst_node_t nidx) const;
/**
* @brief Set the leaf weight for a multi-target tree.
* @brief Set the root weight for a multi-target tree.
*/
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
void SetRoot(linalg::VectorView<float const> weight) {
CHECK(IsMultiTarget());
return this->p_mt_tree_->SetLeaf(nidx, weight);
return this->p_mt_tree_->SetRoot(weight);
}
/**
* @brief Get the maximum depth.
Expand Down
33 changes: 17 additions & 16 deletions include/xgboost/tree_updater.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/**
* Copyright 2014-2023 by XGBoost Contributors
* \file tree_updater.h
* \brief General primitive for tree learning,
* Copyright 2014-2025, XGBoost Contributors
*
* @brief General primitive for tree learning,
* Updating a collection of trees given the information.
* \author Tianqi Chen
*/
Expand All @@ -10,16 +10,17 @@

#include <dmlc/registry.h>
#include <xgboost/base.h> // for Args, GradientPair
#include <xgboost/data.h> // DMatrix
#include <xgboost/data.h> // for DMatrix
#include <xgboost/gradient.h> // for GradientContainer
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/linalg.h> // for VectorView
#include <xgboost/model.h> // for Configurable
#include <xgboost/span.h> // for Span
#include <xgboost/tree_model.h> // for RegTree

#include <functional> // for function
#include <string> // for string
#include <vector> // for vector
#include <functional> // for function
#include <string> // for string
#include <vector> // for vector

namespace xgboost {
namespace tree {
Expand Down Expand Up @@ -59,21 +60,21 @@ class TreeUpdater : public Configurable {
*/
[[nodiscard]] virtual bool HasNodePosition() const { return false; }
/**
* \brief perform update to the tree models
* @brief perform update to the tree models
*
* \param param Hyper-parameter for constructing trees.
* \param gpair the gradient pair statistics of the data
* \param data The data matrix passed to the updater.
* \param out_position The leaf index for each row. The index is negated if that row is
* @param param Hyper-parameter for constructing trees.
* @param gpair The gradient pair statistics of the data
* @param p_fmat The data matrix passed to the updater.
* @param out_position The leaf index for each row. The index is negated if that row is
removed during sampling. So the 3rd node is ~3.
* \param out_trees references the trees to be updated, updater will change the content of trees
* @param out_trees references the trees to be updated, updater will change the content of trees
* note: all the trees in the vector are updated, with the same statistics,
* but maybe different random seeds, usually one tree is passed in at a time,
* there can be multiple trees when we train random forest style model
*/
virtual void Update(tree::TrainParam const* param, linalg::Matrix<GradientPair>* gpair,
DMatrix* data, common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& out_trees) = 0;
virtual void Update(tree::TrainParam const* param, GradientContainer* gpair, DMatrix* p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
std::vector<RegTree*> const& out_trees) = 0;

/*!
* \brief determines whether updater has enough knowledge about a given dataset
Expand Down
5 changes: 3 additions & 2 deletions plugin/sycl/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include "xgboost/gradient.h" // for GradientContainer
#include "xgboost/tree_updater.h"
#pragma GCC diagnostic pop

Expand Down Expand Up @@ -72,11 +73,11 @@ void QuantileHistMaker::CallUpdate(
}
}

void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
linalg::Matrix<GradientPair>* gpair,
void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, GradientContainer *in_gpair,
DMatrix *dmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) {
auto gpair = in_gpair->FullGradOnly();
gpair->Data()->SetDevice(ctx_->Device());
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
updater_monitor_.Start("GmatInitialization");
Expand Down
22 changes: 10 additions & 12 deletions plugin/sycl/tree/updater_quantile_hist.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2017-2024 by Contributors
/**
* Copyright 2017-2025, XGBoost Contributors
* \file updater_quantile_hist.h
*/
#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
Expand All @@ -8,21 +8,21 @@
#include <dmlc/timer.h>
#include <xgboost/tree_updater.h>

#include <vector>
#include <memory>
#include <vector>

#include "../data/gradient_index.h"
#include "../../src/common/random.h"
#include "../../src/tree/constraints.h"
#include "../common/hist_util.h"
#include "../common/row_set.h"
#include "../common/partition_builder.h"
#include "split_evaluator.h"
#include "../common/row_set.h"
#include "../data/gradient_index.h"
#include "../device_manager.h"
#include "hist_updater.h"
#include "split_evaluator.h"
#include "xgboost/data.h"

#include "xgboost/gradient.h" // for GradientContainer
#include "xgboost/json.h"
#include "../../src/tree/constraints.h"
#include "../../src/common/random.h"

namespace xgboost {
namespace sycl {
Expand All @@ -48,9 +48,7 @@ class QuantileHistMaker: public TreeUpdater {
}
void Configure(const Args& args) override;

void Update(xgboost::tree::TrainParam const *param,
linalg::Matrix<GradientPair>* gpair,
DMatrix* dmat,
void Update(xgboost::tree::TrainParam const* param, GradientContainer* in_gpair, DMatrix* dmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override;

Expand Down
Loading
Loading