
Commit 62a3a4b

[mt] Split up the tree model weights, support extmem. (#11814)
1 parent 0e9bc36

20 files changed: +429 -117 lines

include/xgboost/multi_target_tree_model.h

Lines changed: 47 additions & 12 deletions
@@ -25,6 +25,15 @@ struct TreeParam;

 /**
  * @brief Tree structure for multi-target model.
+ *
+ * In order to support reduced gradient, the internal storage distinguishes weights
+ * between base weights and leaf weights. The former is the weight calculated from split
+ * gradient, and the latter is the weight calculated from value gradient and used as
+ * outputs. Every node has a base weight, but only leaves have leaf weights.
+ *
+ * To access the leaf weights, we re-use the right child to store leaf indices. For split
+ * nodes, the `right_` member stores their right child node indices; for leaf nodes, the
+ * `right_` member stores the corresponding leaf weight indices.
  */
 class MultiTargetTree : public Model {
  public:
@@ -33,24 +42,36 @@ class MultiTargetTree : public Model {

  private:
   TreeParam const* param_;
+  // Mapping from node index to its left child. -1 for a leaf node.
   HostDeviceVector<bst_node_t> left_;
+  // Mapping from node index to its right child. Maps to leaf weight for a leaf node.
   HostDeviceVector<bst_node_t> right_;
+  // Mapping from node index to its parent.
   HostDeviceVector<bst_node_t> parent_;
+  // Feature index for node split.
   HostDeviceVector<bst_feature_t> split_index_;
+  // Whether the left child is the default node when split feature is missing.
   HostDeviceVector<std::uint8_t> default_left_;
+  // Threshold for splitting a node.
   HostDeviceVector<float> split_conds_;
+  // Internal base weights.
   HostDeviceVector<float> weights_;
+  // Output weights.
+  HostDeviceVector<float> leaf_weights_;

   [[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
-    auto beg = nidx * this->NumTargets();
-    auto v = this->weights_.ConstHostSpan().subspan(beg, this->NumTargets());
+    auto beg = nidx * this->NumSplitTargets();
+    auto v = this->weights_.ConstHostSpan().subspan(beg, this->NumSplitTargets());
     return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
   }
-  [[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
-    auto beg = nidx * this->NumTargets();
-    auto v = this->weights_.HostSpan().subspan(beg, this->NumTargets());
+  // Unlike the const version, `NumSplitTargets` is not reliable if the tree can change.
+  [[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx,
+                                                     bst_target_t n_split_targets) {
+    auto beg = nidx * n_split_targets;
+    auto v = this->weights_.HostSpan().subspan(beg, n_split_targets);
     return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
   }
+  [[nodiscard]] bst_node_t LeafIdx(bst_node_t nidx) const { return this->RightChild(nidx); }

  public:
   explicit MultiTargetTree(TreeParam const* param);
@@ -72,6 +93,8 @@ class MultiTargetTree : public Model {
                  linalg::VectorView<float const> right_weight);
   /** @see RegTree::SetLeaves */
   void SetLeaves(std::vector<bst_node_t> leaves, common::Span<float const> weights);
+  /** @brief Copy base weight into leaf weight for a non-reduced multi-target tree. */
+  void SetLeaves();

   [[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
     return left_.ConstHostVector()[nidx] == InvalidNodeId();
@@ -82,24 +105,36 @@ class MultiTargetTree : public Model {
   [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
     return right_.ConstHostVector().at(nidx);
   }
-
+  /**
+   * @brief Number of targets (size of a leaf).
+   */
   [[nodiscard]] bst_target_t NumTargets() const;
-  [[nodiscard]] auto NumLeaves() const { return this->weights_.Size() / this->NumTargets(); }
+  /**
+   * @brief Number of reduced targets.
+   */
+  [[nodiscard]] bst_target_t NumSplitTargets() const;
+  [[nodiscard]] auto NumLeaves() const { return this->leaf_weights_.Size() / this->NumTargets(); }

   [[nodiscard]] std::size_t Size() const;
   [[nodiscard]] MultiTargetTree* Copy(TreeParam const* param) const;

-  common::Span<float const> Weights(DeviceOrd device) const {
+  common::Span<float const> LeafWeights(DeviceOrd device) const {
     if (device.IsCPU()) {
-      return this->weights_.ConstHostSpan();
+      return this->leaf_weights_.ConstHostSpan();
     }
-    this->weights_.SetDevice(device);
-    return this->weights_.ConstDeviceSpan();
+    this->leaf_weights_.SetDevice(device);
+    return this->leaf_weights_.ConstDeviceSpan();
   }

   [[nodiscard]] linalg::VectorView<float const> LeafValue(bst_node_t nidx) const {
     CHECK(IsLeaf(nidx));
-    return this->NodeWeight(nidx);
+    auto n_targets = this->NumTargets();
+    auto h_leaf_mapping = this->right_.ConstHostSpan();
+    auto h_leaf_weights = this->leaf_weights_.ConstHostSpan();
+    auto lidx = h_leaf_mapping[nidx];
+    CHECK_NE(lidx, InvalidNodeId());
+    auto weight = h_leaf_weights.subspan(lidx * n_targets, n_targets);
+    return linalg::MakeVec(DeviceOrd::CPU(), weight);
   }

   void LoadModel(Json const& in) override;
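
The weight layout described in the class comment can be summarized with a small standalone sketch. This is not XGBoost code: it uses plain std::vector and std::span stand-ins for HostDeviceVector and linalg::VectorView, and the type and member names are hypothetical, but it shows how the right-child slot doubles as the leaf-index mapping and why base weights and leaf weights can have different widths.

// Standalone sketch (hypothetical names, standard-library types only) of the storage
// layout: every node has a base weight of n_split_targets values, while only leaves
// have an output weight of n_targets values, reached through the re-used right slot.
#include <cassert>
#include <cstdint>
#include <span>
#include <vector>

struct ToyMultiTargetTree {
  static constexpr std::int32_t kInvalidNodeId = -1;

  std::vector<std::int32_t> left;    // child index, kInvalidNodeId for leaves
  std::vector<std::int32_t> right;   // child index for splits, leaf index for leaves
  std::vector<float> base_weights;   // n_split_targets values per node
  std::vector<float> leaf_weights;   // n_targets values per leaf
  std::int32_t n_targets{0};
  std::int32_t n_split_targets{0};

  bool IsLeaf(std::int32_t nidx) const { return left[nidx] == kInvalidNodeId; }

  // Base weight, computed from the (possibly reduced) split gradient.
  std::span<float const> NodeWeight(std::int32_t nidx) const {
    return std::span<float const>{base_weights}.subspan(nidx * n_split_targets,
                                                        n_split_targets);
  }

  // Output weight, computed from the value gradient; the right slot holds the index
  // into leaf_weights.
  std::span<float const> LeafValue(std::int32_t nidx) const {
    assert(IsLeaf(nidx));
    auto lidx = right[nidx];
    return std::span<float const>{leaf_weights}.subspan(lidx * n_targets, n_targets);
  }
};

With a reduced-gradient objective, n_split_targets can be smaller than n_targets, which is why NumLeaves() is now derived from leaf_weights_ rather than from the base weight buffer.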

include/xgboost/tree_model.h

Lines changed: 2 additions & 0 deletions
@@ -408,6 +408,8 @@ class RegTree : public Model {
   [[nodiscard]] bst_node_t GetDepth(bst_node_t nidx) const;
   /**
    * @brief Set the root weight for a multi-target tree.
+   *
+   * @param weight Internal split weight, with size equal to the number of reduced targets.
    */
   void SetRoot(linalg::VectorView<float const> weight) {
     CHECK(IsMultiTarget());

python-package/xgboost/testing/data.py

Lines changed: 3 additions & 5 deletions
@@ -30,6 +30,7 @@
 from numpy.random import Generator as RNG
 from scipy import sparse

+from ..compat import concat
 from ..core import DataIter, DMatrix, QuantileDMatrix
 from ..data import is_pd_cat_dtype, pandas_pyarrow_mapper
 from ..sklearn import ArrayLike, XGBRanker
@@ -1150,11 +1151,8 @@ def as_arrays(
         self,
     ) -> Tuple[Union[np.ndarray, sparse.csr_matrix], ArrayLike, Optional[ArrayLike]]:
         """Return concatenated arrays."""
-        if isinstance(self.X[0], sparse.csr_matrix):
-            X = sparse.vstack(self.X, format="csr")
-        else:
-            X = np.concatenate(self.X, axis=0)
-        y = np.concatenate(self.y, axis=0)
+        X = concat(self.X)
+        y = concat(self.y)
         if self.w:
             w = np.concatenate(self.w, axis=0)
         else:

python-package/xgboost/testing/multi_target.py

Lines changed: 117 additions & 25 deletions
@@ -13,12 +13,14 @@
 import xgboost.testing as tm

 from .._typing import ArrayLike
-from ..core import Booster, DMatrix, QuantileDMatrix
+from ..compat import import_cupy
+from ..core import Booster, DMatrix, ExtMemQuantileDMatrix, QuantileDMatrix
 from ..objective import Objective, TreeObjective
 from ..sklearn import XGBClassifier
 from ..training import train
+from .data import IteratorForTest
 from .updater import ResetStrategy
-from .utils import Device
+from .utils import Device, assert_allclose


 def run_multiclass(device: Device, learning_rate: Optional[float]) -> None:
@@ -64,34 +66,42 @@ def run_multilabel(device: Device, learning_rate: Optional[float]) -> None:
     assert proba.shape == y.shape


-def run_reduced_grad(device: Device) -> None:
-    """Basic test for using reduced gradient for tree splits."""
-    import cupy as cp
+class LsObj0(TreeObjective):
+    """Split grad is the same as value grad."""

-    class LsObj0(TreeObjective):
-        """Split grad is the same as value grad."""
+    def __call__(
+        self, y_pred: ArrayLike, dtrain: DMatrix
+    ) -> Tuple[ArrayLike, ArrayLike]:
+        cp = import_cupy()

-        def __call__(
-            self, y_pred: ArrayLike, dtrain: DMatrix
-        ) -> Tuple[cp.ndarray, cp.ndarray]:
-            y_true = dtrain.get_label().reshape(y_pred.shape)
-            grad, hess = tm.ls_obj(y_true, y_pred, None)
-            return cp.array(grad), cp.array(hess)
+        y_true = dtrain.get_label().reshape(y_pred.shape)
+        grad, hess = tm.ls_obj(y_true, y_pred, None)
+        return cp.array(grad), cp.array(hess)

-        def split_grad(
-            self, grad: ArrayLike, hess: ArrayLike
-        ) -> Tuple[ArrayLike, ArrayLike]:
-            return cp.array(grad), cp.array(hess)
+    def split_grad(
+        self, grad: ArrayLike, hess: ArrayLike
+    ) -> Tuple[ArrayLike, ArrayLike]:
+        cp = import_cupy()

-    class LsObj1(Objective):
-        """No split grad."""
+        return cp.array(grad), cp.array(hess)

-        def __call__(
-            self, y_pred: ArrayLike, dtrain: DMatrix
-        ) -> Tuple[cp.ndarray, cp.ndarray]:
-            y_true = dtrain.get_label().reshape(y_pred.shape)
-            grad, hess = tm.ls_obj(y_true, y_pred, None)
-            return cp.array(grad), cp.array(hess)
+
+class LsObj1(Objective):
+    """No split grad."""
+
+    def __call__(
+        self, y_pred: ArrayLike, dtrain: DMatrix
+    ) -> Tuple[ArrayLike, ArrayLike]:
+        cp = import_cupy()
+
+        y_true = dtrain.get_label().reshape(y_pred.shape)
+        grad, hess = tm.ls_obj(y_true, y_pred, None)
+        return cp.array(grad), cp.array(hess)
+
+
+def run_reduced_grad(device: Device) -> None:
+    """Basic test for using reduced gradient for tree splits."""
+    import cupy as cp

     X, y = make_regression(  # pylint: disable=unbalanced-tuple-unpacking
         n_samples=1024, n_features=16, random_state=1994, n_targets=5
@@ -149,3 +159,85 @@ def split_grad(
     run_test(LsObj2(False))
     with pytest.raises(AssertionError):
         run_test(LsObj2(True))
+
+
+def run_with_iter(device: Device) -> None:  # pylint: disable=too-many-locals
+    """Test vector leaf with external memory."""
+    if device == "cuda":
+        from cupy import asarray
+    else:
+        from numpy import asarray
+
+    n_batches = 4
+    n_rounds = 8
+    n_targets = 3
+    intercept = [0.5] * n_targets
+    Xs = []
+    ys = []
+    for i in range(n_batches):
+        X_i, y_i = make_regression(  # pylint: disable=unbalanced-tuple-unpacking
+            n_samples=4096, n_features=8, random_state=(i + 1), n_targets=n_targets
+        )
+        Xs.append(asarray(X_i))
+        ys.append(asarray(y_i))
+    it = IteratorForTest(Xs, ys, None, cache="cache", on_host=True)
+    Xy: DMatrix = ExtMemQuantileDMatrix(it, cache_host_ratio=1.0)
+
+    evals_result_0: Dict[str, Dict] = {}
+    booster_0 = train(
+        {
+            "device": device,
+            "multi_strategy": "multi_output_tree",
+            "learning_rate": 1.0,
+            "base_score": intercept,
+        },
+        Xy,
+        num_boost_round=n_rounds,
+        evals=[(Xy, "Train")],
+        evals_result=evals_result_0,
+    )
+
+    it = IteratorForTest(Xs, ys, None, cache=None)
+    Xy = QuantileDMatrix(it)
+    evals_result_1: Dict[str, Dict] = {}
+    booster_1 = train(
+        {
+            "device": device,
+            "multi_strategy": "multi_output_tree",
+            "learning_rate": 1.0,
+            "base_score": intercept,
+        },
+        Xy,
+        num_boost_round=n_rounds,
+        evals=[(Xy, "Train")],
+        evals_result=evals_result_1,
+    )
+    np.testing.assert_allclose(
+        evals_result_0["Train"]["rmse"], evals_result_1["Train"]["rmse"]
+    )
+    assert tm.non_increasing(evals_result_0["Train"]["rmse"])
+    X, _, _ = it.as_arrays()
+    assert_allclose(device, booster_0.inplace_predict(X), booster_1.inplace_predict(X))
+
+    it = IteratorForTest(Xs, ys, None, cache="cache", on_host=True)
+    Xy = ExtMemQuantileDMatrix(it, cache_host_ratio=1.0)
+
+    evals_result_2: Dict[str, Dict] = {}
+    booster_2 = train(
+        {
+            "device": device,
+            "multi_strategy": "multi_output_tree",
+            "learning_rate": 1.0,
+            "base_score": intercept,
+            "debug_synchronize": True,
+        },
+        Xy,
+        evals=[(Xy, "Train")],
+        obj=LsObj0(),
+        num_boost_round=n_rounds,
+        evals_result=evals_result_2,
+    )
+    np.testing.assert_allclose(
+        evals_result_0["Train"]["rmse"], evals_result_2["Train"]["rmse"]
+    )
+    assert_allclose(device, booster_0.inplace_predict(X), booster_2.inplace_predict(X))

src/common/cuda_rt_utils.cc

Lines changed: 8 additions & 0 deletions
@@ -3,6 +3,8 @@
  */
 #include "cuda_rt_utils.h"

+#include "cuda_stream.h"  // for StreamRef
+
 #if defined(XGBOOST_USE_CUDA)
 #include <cuda_runtime_api.h>

@@ -99,6 +101,10 @@ void GetDrVersionGlobal(std::int32_t* major, std::int32_t* minor) {
   return numa_id;
 }

+void MemcpyAsync(void* dst, const void* src, std::size_t count, StreamRef stream) {
+  dh::safe_cuda(cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, stream));
+}
+
 #else
 std::int32_t AllVisibleGPUs() { return 0; }

@@ -128,5 +134,7 @@ void SetDevice(std::int32_t device) {
   return 0;
 }

+void MemcpyAsync(void*, const void*, std::size_t, StreamRef) { common::AssertGPUSupport(); }
+
 #endif  // !defined(XGBOOST_USE_CUDA)
 }  // namespace xgboost::curt

src/common/cuda_rt_utils.h

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,8 @@
 #include <cstddef>  // for size_t
 #include <cstdint>  // for int32_t

+#include "cuda_stream.h"  // for StreamRef
+
 namespace xgboost::curt {
 std::int32_t AllVisibleGPUs();

@@ -35,4 +37,7 @@ void GetDrVersionGlobal(std::int32_t* major, std::int32_t* minor);

 // Get the current device's numa ID.
 [[nodiscard]] std::int32_t GetNumaId();
+
+// cudaMemcpyAsync
+void MemcpyAsync(void* dst, const void* src, std::size_t count, StreamRef stream);
 }  // namespace xgboost::curt
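
The new wrapper forwards to cudaMemcpyAsync with cudaMemcpyDefault, so the copy direction is inferred from the pointer kinds. Below is a minimal usage sketch, assuming a CUDA build, that d_out is a valid device allocation of at least h_in.size() floats, and that DefaultStream() is the stream helper declared in cuda_stream.h; the function and variable names are made up for illustration.

#include <cstddef>  // for size_t
#include <vector>

#include "cuda_rt_utils.h"  // for MemcpyAsync
#include "cuda_stream.h"    // for DefaultStream

namespace curt = xgboost::curt;

// Enqueue a host-to-device copy on the default stream; cudaMemcpyDefault lets the
// runtime work out the direction from the pointers themselves.
void CopyToDevice(std::vector<float> const& h_in, float* d_out) {
  curt::MemcpyAsync(d_out, h_in.data(), h_in.size() * sizeof(float), curt::DefaultStream());
}

On a CPU-only build the same call still compiles thanks to the StreamRef stub in cuda_stream.h, but MemcpyAsync asserts at runtime through common::AssertGPUSupport().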

src/common/cuda_stream.h

Lines changed: 12 additions & 0 deletions
@@ -2,14 +2,18 @@
  * Copyright 2022-2025, XGBoost contributors
  */
 #pragma once
+
+#if defined(XGBOOST_USE_CUDA)
 #include <cuda_runtime.h>
+#endif  // defined(XGBOOST_USE_CUDA)

 #include <memory>   // for unique_ptr
 #include <utility>  // for swap

 #include "common.h"

 namespace xgboost::curt {
+#if defined(XGBOOST_USE_CUDA)
 class StreamRef;

 class Event {
@@ -94,4 +98,12 @@ class Stream {
   void Sync() { this->View().Sync(); }
   void Wait(Event const &e) { this->View().Wait(e); }
 };
+#else
+class StreamRef {};
+
+inline StreamRef DefaultStream() {
+  common::AssertGPUSupport();
+  return StreamRef{};
+}
+#endif
 }  // namespace xgboost::curt
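
The #else branch above is what lets CPU-only builds compile against stream-taking signatures such as MemcpyAsync: StreamRef becomes an empty type and DefaultStream() only fails at runtime via common::AssertGPUSupport(). A hypothetical, self-contained illustration of the same stub pattern follows; none of these names are from XGBoost.

#include <cstddef>
#include <cstdio>
#include <cstdlib>

#if defined(USE_GPU)
// In a real GPU build this would wrap an actual stream handle.
struct StreamHandle {
  void* raw{nullptr};
};
inline StreamHandle DefaultStreamHandle() { return {}; }
inline void CopyAsync(void* dst, void const* src, std::size_t n, StreamHandle) {
  // ... enqueue the asynchronous copy on the stream ...
  (void)dst; (void)src; (void)n;
}
#else
// CPU-only build: the empty stub keeps every call site compiling; misuse is reported
// at runtime instead of turning each caller into an #ifdef.
struct StreamHandle {};
inline StreamHandle DefaultStreamHandle() {
  std::fprintf(stderr, "not built with GPU support\n");
  std::abort();
}
inline void CopyAsync(void*, void const*, std::size_t, StreamHandle) {
  std::fprintf(stderr, "not built with GPU support\n");
  std::abort();
}
#endif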
