Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions include/xgboost/gbm.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
#include <dmlc/registry.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <xgboost/gradient.h> // for GradientContainer
#include <xgboost/host_device_vector.h>
#include <xgboost/model.h>

#include <vector>
#include <string>
#include <functional>
#include <memory>
#include <string>
#include <vector>

namespace xgboost {

Expand Down Expand Up @@ -78,8 +79,8 @@ class GradientBooster : public Model, public Configurable {
* the booster may change content of gpair
* @param obj The objective function used for boosting.
*/
virtual void DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
PredictionCacheEntry*, ObjFunction const* obj) = 0;
virtual void DoBoost(DMatrix* /*p_fmat*/, GradientContainer* /*in_gpair*/,
PredictionCacheEntry* /*prediction*/, ObjFunction const* /*obj*/) = 0;

/**
* \brief Generate predictions for given feature matrix
Expand Down
52 changes: 52 additions & 0 deletions include/xgboost/gradient.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/**
* Copyright 2025, XGBoost Contributors
*/
#pragma once

#include <xgboost/base.h> // for GradientPair
#include <xgboost/linalg.h> // for Matrix
#include <xgboost/logging.h>

#include <cstddef> // for size_t

namespace xgboost {
/**
* @brief Container for gradient produced by objective.
*/
struct GradientContainer {
  /** @brief Gradient used for multi-target tree split and linear model. */
  linalg::Matrix<GradientPair> gpair;
  /** @brief Gradient used for tree leaf value, optional. */
  linalg::Matrix<GradientPair> value_gpair;

  /** @brief Whether a separate gradient for leaf values is present. */
  [[nodiscard]] bool HasValueGrad() const noexcept { return !value_gpair.Empty(); }

  /** @brief Number of targets used for finding tree splits. */
  [[nodiscard]] std::size_t NumSplitTargets() const noexcept { return gpair.Shape(1); }
  /** @brief Total number of targets; falls back to the split gradient when no value gradient. */
  [[nodiscard]] std::size_t NumTargets() const noexcept {
    return HasValueGrad() ? value_gpair.Shape(1) : this->gpair.Shape(1);
  }

  /**
   * @brief View of the gradient used for leaf values.
   *
   * Falls back to the split gradient when no dedicated value gradient is stored.
   *
   * @param ctx Context used to select the device for the view.
   */
  [[nodiscard]] linalg::MatrixView<GradientPair const> ValueGrad(Context const* ctx) const {
    if (HasValueGrad()) {
      return this->value_gpair.View(ctx->Device());
    }
    return this->gpair.View(ctx->Device());
  }

  /** @brief Accessor for the split gradient. */
  [[nodiscard]] linalg::Matrix<GradientPair> const* Grad() const { return &gpair; }
  /** @brief Mutable accessor for the split gradient. */
  [[nodiscard]] linalg::Matrix<GradientPair>* Grad() { return &gpair; }

  /**
   * @brief Accessor for the split gradient, aborts if a reduced (value) gradient is present.
   *
   * Used by components that cannot yet handle a separate leaf-value gradient.
   */
  [[nodiscard]] linalg::Matrix<GradientPair> const* FullGradOnly() const {
    if (this->HasValueGrad()) {
      LOG(FATAL) << "Reduced gradient is not yet supported.";
    }
    return this->Grad();
  }
  /** @brief Mutable overload of @ref FullGradOnly. */
  [[nodiscard]] linalg::Matrix<GradientPair>* FullGradOnly() {
    if (this->HasValueGrad()) {
      LOG(FATAL) << "Reduced gradient is not yet supported.";
    }
    return this->Grad();
  }
};
} // namespace xgboost
44 changes: 22 additions & 22 deletions include/xgboost/learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,23 @@
#ifndef XGBOOST_LEARNER_H_
#define XGBOOST_LEARNER_H_

#include <dmlc/io.h> // for Serializable
#include <xgboost/base.h> // for bst_feature_t, bst_target_t, bst_float, Args, GradientPair, ..
#include <xgboost/context.h> // for Context
#include <xgboost/linalg.h> // for Vector, VectorView
#include <xgboost/metric.h> // for Metric
#include <xgboost/model.h> // for Configurable, Model
#include <xgboost/span.h> // for Span
#include <xgboost/task.h> // for ObjInfo
#include <dmlc/io.h> // for Serializable
#include <xgboost/base.h> // for bst_feature_t, bst_target_t, bst_float, Args, GradientPair, ..
#include <xgboost/context.h> // for Context
#include <xgboost/gradient.h> // for GradientContainer
#include <xgboost/linalg.h> // for Vector, VectorView
#include <xgboost/metric.h> // for Metric
#include <xgboost/model.h> // for Configurable, Model
#include <xgboost/span.h> // for Span
#include <xgboost/task.h> // for ObjInfo

#include <algorithm> // for max
#include <cstdint> // for int32_t, uint32_t, uint8_t
#include <map> // for map
#include <memory> // for shared_ptr, unique_ptr
#include <string> // for string
#include <utility> // for move
#include <vector> // for vector
#include <algorithm> // for max
#include <cstdint> // for int32_t, uint32_t, uint8_t
#include <map> // for map
#include <memory> // for shared_ptr, unique_ptr
#include <string> // for string
#include <utility> // for move
#include <vector> // for vector

namespace xgboost {
class FeatureMap;
Expand All @@ -47,25 +48,24 @@ enum class PredictionType : std::uint8_t { // NOLINT
kLeaf = 6
};

/*!
* \brief Learner class that does training and prediction.
/**
* @brief Learner class that does training and prediction.
* This is the user facing module of xgboost training.
* The Load/Save function corresponds to the model used in python/R.
* \code
* @code
*
* std::unique_ptr<Learner> learner(Learner::Create(cache_mats));
* learner.Configure(configs);
* learner->Configure(configs);
*
* for (int iter = 0; iter < max_iter; ++iter) {
* learner->UpdateOneIter(iter, train_mat);
* LOG(INFO) << learner->EvalOneIter(iter, data_sets, data_names);
* }
*
* \endcode
* @endcode
*/
class Learner : public Model, public Configurable, public dmlc::Serializable {
public:
/*! \brief virtual destructor */
~Learner() override;
/*!
* \brief Configure Learner based on set parameters.
Expand All @@ -88,7 +88,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* @param in_gpair The input gradient statistics.
*/
virtual void BoostOneIter(std::int32_t iter, std::shared_ptr<DMatrix> train,
linalg::Matrix<GradientPair>* in_gpair) = 0;
GradientContainer* in_gpair) = 0;
/*!
* \brief evaluate the model for specific iteration using the configured metrics.
* \param iter iteration number
Expand Down
13 changes: 12 additions & 1 deletion include/xgboost/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ template <typename T>
using Vector = Tensor<T, 1>;

/**
* \brief Create an array without initialization.
* @brief Create an array without initialization.
*/
template <typename T, typename... Index>
auto Empty(Context const *ctx, Index &&...index) {
Expand All @@ -967,6 +967,17 @@ auto Empty(Context const *ctx, Index &&...index) {
return t;
}

/**
 * @brief Create an array with the same shape and dtype as the input.
 *
 * The result is placed on the device of @p ctx, regardless of where the input
 * lives. Contents are presumably uninitialized like @ref Empty — confirm the
 * behavior of Tensor::Reshape for newly grown storage.
 *
 * @param ctx Context used to select the target device.
 * @param in  Tensor whose shape (and element type) is copied.
 *
 * @return A new tensor with the same shape and dtype as @p in.
 */
template <typename T, std::int32_t kDim>
auto EmptyLike(Context const *ctx, Tensor<T, kDim> const &in) {
Tensor<T, kDim> t;
t.SetDevice(ctx->Device());
t.Reshape(in.Shape());
return t;
}

/**
* \brief Create an array with value v.
*/
Expand Down
5 changes: 3 additions & 2 deletions include/xgboost/multi_target_tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,17 @@ class MultiTargetTree : public Model {
MultiTargetTree& operator=(MultiTargetTree&& that) = delete;

/**
* @brief Set the weight for a leaf.
* @brief Set the weight for the root.
*/
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight);
void SetRoot(linalg::VectorView<float const> weight);
/**
* @brief Expand a leaf into split node.
*/
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight);
void SetLeaves(std::vector<bst_node_t> leaves, common::Span<float const> weights);

[[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
return left_.ConstHostVector()[nidx] == InvalidNodeId();
Expand Down
2 changes: 1 addition & 1 deletion include/xgboost/objective.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class ObjFunction : public Configurable {
virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
MetaInfo const& /*info*/, float /*learning_rate*/,
HostDeviceVector<float> const& /*prediction*/,
std::int32_t /*group_idx*/, RegTree* /*p_tree*/) const {}
bst_target_t /*group_idx*/, RegTree* /*p_tree*/) const {}
/**
* @brief Create an objective function according to the name.
*
Expand Down
19 changes: 15 additions & 4 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,23 @@ class RegTree : public Model {
float right_sum,
bst_node_t leaf_right_child = kInvalidNodeId);
/**
* \brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
* @brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
*/
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight);
/**
* @brief Set all leaf weights for a multi-target tree.
*
* The leaf weight can be different from the internal weight stored by @ref ExpandNode.
* This function is used to set the leaf at the end of tree construction.
*
* @param leaves The node indices for all leaves. This must contain all the leaves in this tree.
* @param weights Row-major matrix for leaf weights, each row contains a leaf specified by the
* leaves parameter.
*/
void SetLeaves(std::vector<bst_node_t> leaves, common::Span<float const> weights);

/**
* \brief Expands a leaf node with categories
Expand Down Expand Up @@ -396,11 +407,11 @@ class RegTree : public Model {
*/
[[nodiscard]] bst_node_t GetDepth(bst_node_t nidx) const;
/**
* @brief Set the leaf weight for a multi-target tree.
* @brief Set the root weight for a multi-target tree.
*/
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
void SetRoot(linalg::VectorView<float const> weight) {
CHECK(IsMultiTarget());
return this->p_mt_tree_->SetLeaf(nidx, weight);
return this->p_mt_tree_->SetRoot(weight);
}
/**
* @brief Get the maximum depth.
Expand Down
33 changes: 17 additions & 16 deletions include/xgboost/tree_updater.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/**
* Copyright 2014-2023 by XGBoost Contributors
* \file tree_updater.h
* \brief General primitive for tree learning,
* Copyright 2014-2025, XGBoost Contributors
*
* @brief General primitive for tree learning,
* Updating a collection of trees given the information.
* \author Tianqi Chen
*/
Expand All @@ -10,16 +10,17 @@

#include <dmlc/registry.h>
#include <xgboost/base.h> // for Args, GradientPair
#include <xgboost/data.h> // DMatrix
#include <xgboost/data.h> // for DMatrix
#include <xgboost/gradient.h> // for GradientContainer
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/linalg.h> // for VectorView
#include <xgboost/model.h> // for Configurable
#include <xgboost/span.h> // for Span
#include <xgboost/tree_model.h> // for RegTree

#include <functional> // for function
#include <string> // for string
#include <vector> // for vector
#include <functional> // for function
#include <string> // for string
#include <vector> // for vector

namespace xgboost {
namespace tree {
Expand Down Expand Up @@ -59,21 +60,21 @@ class TreeUpdater : public Configurable {
*/
[[nodiscard]] virtual bool HasNodePosition() const { return false; }
/**
* \brief perform update to the tree models
* @brief perform update to the tree models
*
* \param param Hyper-parameter for constructing trees.
* \param gpair the gradient pair statistics of the data
* \param data The data matrix passed to the updater.
* \param out_position The leaf index for each row. The index is negated if that row is
* @param param Hyper-parameter for constructing trees.
* @param gpair The gradient pair statistics of the data
* @param p_fmat The data matrix passed to the updater.
* @param out_position The leaf index for each row. The index is negated if that row is
* removed during sampling. So the 3rd node is ~3.
* \param out_trees references the trees to be updated, updater will change the content of trees
* @param out_trees references the trees to be updated, updater will change the content of trees
* note: all the trees in the vector are updated, with the same statistics,
* but maybe different random seeds, usually one tree is passed in at a time,
* there can be multiple trees when we train random forest style model
*/
virtual void Update(tree::TrainParam const* param, linalg::Matrix<GradientPair>* gpair,
DMatrix* data, common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& out_trees) = 0;
virtual void Update(tree::TrainParam const* param, GradientContainer* gpair, DMatrix* p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
std::vector<RegTree*> const& out_trees) = 0;

/*!
* \brief determines whether updater has enough knowledge about a given dataset
Expand Down
Loading
Loading