Commit 2cace85

Revert "fix training continuation for iGPUs (#71)"

This reverts commit 3d067f4.

1 parent 3d067f4 · commit 2cace85

File tree

5 files changed: +89 -52 lines changed

plugin/sycl/data/gradient_index.cc — 39 additions, 37 deletions

@@ -50,29 +50,28 @@ void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) {
 
 template <typename BinIdxType, bool isDense>
 void GHistIndexMatrix::SetIndexData(::sycl::queue* qu,
-                                    Context const * ctx,
                                     BinIdxType* index_data,
-                                    DMatrix *dmat) {
+                                    DMatrix *dmat,
+                                    size_t nbins,
+                                    size_t row_stride) {
   if (nbins == 0) return;
   const bst_float* cut_values = cut.cut_values_.ConstDevicePointer();
   const uint32_t* cut_ptrs = cut.cut_ptrs_.ConstDevicePointer();
   size_t* hit_count_ptr = hit_count.DevicePointer();
 
   BinIdxType* sort_data = reinterpret_cast<BinIdxType*>(sort_buff.Data());
 
+  ::sycl::event event;
-  for (auto &batch : dmat->GetBatches<SparsePage>()) {
-    batch.data.SetDevice(ctx->Device());
-    batch.offset.SetDevice(ctx->Device());
-
-    const xgboost::Entry *data_ptr = batch.data.ConstDevicePointer();
-    const bst_idx_t *offset_vec = batch.offset.ConstDevicePointer();
-    size_t batch_size = batch.Size();
-    if (batch_size > 0) {
-      const auto base_rowid = batch.base_rowid;
-      size_t row_stride = this->row_stride;
-      size_t nbins = this->nbins;
-      qu->submit([&](::sycl::handler& cgh) {
-        cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::item<1> pid) {
+  for (auto &batch : dmat->GetBatches<SparsePage>()) {
+    const xgboost::Entry *data_ptr = batch.data.ConstDevicePointer();
+    const bst_idx_t *offset_vec = batch.offset.ConstDevicePointer();
+    size_t batch_size = batch.Size();
+    if (batch_size > 0) {
+      const auto base_rowid = batch.base_rowid;
+      event = qu->submit([&](::sycl::handler& cgh) {
+        cgh.depends_on(event);
+        cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::item<1> pid) {
           const size_t i = pid.get_id(0);
           const size_t ibegin = offset_vec[i];
           const size_t iend = offset_vec[i + 1];
@@ -93,22 +92,23 @@ void GHistIndexMatrix::SetIndexData(::sycl::queue* qu,
           }
         });
       });
-      qu->wait();
+      }
     }
   }
+  qu->wait();
 }
 
-void GHistIndexMatrix::ResizeIndex(::sycl::queue* qu, size_t n_index) {
-  if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense_) {
+void GHistIndexMatrix::ResizeIndex(size_t n_index, bool isDense) {
+  if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
     index.SetBinTypeSize(BinTypeSize::kUint8BinsTypeSize);
-    index.Resize(qu, (sizeof(uint8_t)) * n_index);
+    index.Resize((sizeof(uint8_t)) * n_index);
   } else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
-             max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense_) {
+             max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
     index.SetBinTypeSize(BinTypeSize::kUint16BinsTypeSize);
-    index.Resize(qu, (sizeof(uint16_t)) * n_index);
+    index.Resize((sizeof(uint16_t)) * n_index);
   } else {
     index.SetBinTypeSize(BinTypeSize::kUint32BinsTypeSize);
-    index.Resize(qu, (sizeof(uint32_t)) * n_index);
+    index.Resize((sizeof(uint32_t)) * n_index);
   }
 }
 
@@ -122,50 +122,52 @@ void GHistIndexMatrix::Init(::sycl::queue* qu,
   cut.SetDevice(ctx->Device());
 
   max_num_bins = max_bins;
-  nbins = cut.Ptrs().back();
+  const uint32_t nbins = cut.Ptrs().back();
+  this->nbins = nbins;
 
   hit_count.SetDevice(ctx->Device());
   hit_count.Resize(nbins, 0);
 
+  this->p_fmat = dmat;
   const bool isDense = dmat->IsDense();
   this->isDense_ = isDense;
 
+  index.setQueue(qu);
+
   row_stride = 0;
   size_t n_rows = 0;
-  if (!isDense) {
-    for (const auto& batch : dmat->GetBatches<SparsePage>()) {
-      const auto& row_offset = batch.offset.ConstHostVector();
-      n_rows += batch.Size();
-      for (auto i = 1ull; i < row_offset.size(); i++) {
-        row_stride = std::max(row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
-      }
+  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
+    const auto& row_offset = batch.offset.ConstHostVector();
+    batch.data.SetDevice(ctx->Device());
+    batch.offset.SetDevice(ctx->Device());
+    n_rows += batch.Size();
+    for (auto i = 1ull; i < row_offset.size(); i++) {
+      row_stride = std::max(row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
     }
-  } else {
-    row_stride = nfeatures;
-    n_rows = dmat->Info().num_row_;
   }
 
   const size_t n_offsets = cut.cut_ptrs_.Size() - 1;
   const size_t n_index = n_rows * row_stride;
-  ResizeIndex(qu, n_index);
+  ResizeIndex(n_index, isDense);
 
   CHECK_GT(cut.cut_values_.Size(), 0U);
 
   if (isDense) {
     BinTypeSize curent_bin_size = index.GetBinTypeSize();
     if (curent_bin_size == BinTypeSize::kUint8BinsTypeSize) {
-      SetIndexData<uint8_t, true>(qu, ctx, index.data<uint8_t>(), dmat);
+      SetIndexData<uint8_t, true>(qu, index.data<uint8_t>(), dmat, nbins, row_stride);
+
     } else if (curent_bin_size == BinTypeSize::kUint16BinsTypeSize) {
-      SetIndexData<uint16_t, true>(qu, ctx, index.data<uint16_t>(), dmat);
+      SetIndexData<uint16_t, true>(qu, index.data<uint16_t>(), dmat, nbins, row_stride);
     } else {
       CHECK_EQ(curent_bin_size, BinTypeSize::kUint32BinsTypeSize);
-      SetIndexData<uint32_t, true>(qu, ctx, index.data<uint32_t>(), dmat);
+      SetIndexData<uint32_t, true>(qu, index.data<uint32_t>(), dmat, nbins, row_stride);
     }
     /* For sparse DMatrix we have to store index of feature for each bin
        in index field to chose right offset. So offset is nullptr and index is not reduced */
   } else {
     sort_buff.Resize(qu, n_rows * row_stride * sizeof(uint32_t));
-    SetIndexData<uint32_t, false>(qu, ctx, index.data<uint32_t>(), dmat);
+    SetIndexData<uint32_t, false>(qu, index.data<uint32_t>(), dmat, nbins, row_stride);
   }
 }

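Note: the restored SetIndexData submits one kernel per SparsePage batch and chains the submissions through a SYCL event, waiting only once after the loop, where the reverted #71 code instead called qu->wait() inside the loop for every batch. A minimal standalone sketch of that chaining pattern (illustrative only; the buffer, sizes, and kernel body are invented, not code from this commit):

#include <sycl/sycl.hpp>
#include <vector>

int main() {
  sycl::queue q;
  constexpr size_t kBatches = 4;
  constexpr size_t kBatchSize = 256;
  int* data = sycl::malloc_device<int>(kBatches * kBatchSize, q);

  // Chain the per-batch kernels through an event; a default-constructed
  // sycl::event is already complete, so the first depends_on is a no-op.
  sycl::event event;
  for (size_t b = 0; b < kBatches; ++b) {
    int* out = data + b * kBatchSize;
    event = q.submit([&](sycl::handler& cgh) {
      cgh.depends_on(event);
      cgh.parallel_for(sycl::range<1>(kBatchSize), [=](sycl::item<1> pid) {
        out[pid.get_id(0)] = static_cast<int>(b);
      });
    });
  }
  q.wait();  // one wait covers the whole dependency chain

  std::vector<int> host(kBatches * kBatchSize);
  q.memcpy(host.data(), data, host.size() * sizeof(int)).wait();
  sycl::free(data, q);
  return host.back() == static_cast<int>(kBatches) - 1 ? 0 : 1;
}

The event chain keeps the batches ordered without forcing a host round-trip between submissions.
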
plugin/sycl/data/gradient_index.h — 43 additions, 9 deletions

@@ -31,9 +31,21 @@ struct Index {
   Index& operator=(Index&& i) = delete;
   void SetBinTypeSize(BinTypeSize binTypeSize) {
     binTypeSize_ = binTypeSize;
-    CHECK(binTypeSize == BinTypeSize::kUint8BinsTypeSize ||
-          binTypeSize == BinTypeSize::kUint16BinsTypeSize ||
-          binTypeSize == BinTypeSize::kUint32BinsTypeSize);
+    switch (binTypeSize) {
+      case BinTypeSize::kUint8BinsTypeSize:
+        func_ = &GetValueFromUint8;
+        break;
+      case BinTypeSize::kUint16BinsTypeSize:
+        func_ = &GetValueFromUint16;
+        break;
+      case BinTypeSize::kUint32BinsTypeSize:
+        func_ = &GetValueFromUint32;
+        break;
+      default:
+        CHECK(binTypeSize == BinTypeSize::kUint8BinsTypeSize ||
+              binTypeSize == BinTypeSize::kUint16BinsTypeSize ||
+              binTypeSize == BinTypeSize::kUint32BinsTypeSize);
+    }
   }
   BinTypeSize GetBinTypeSize() const {
     return binTypeSize_;
@@ -53,8 +65,8 @@ struct Index {
     return data_.Size() / (binTypeSize_);
   }
 
-  void Resize(::sycl::queue* qu, const size_t nBytesData) {
-    data_.Resize(qu, nBytesData);
+  void Resize(const size_t nBytesData) {
+    data_.Resize(qu_, nBytesData);
   }
 
   uint8_t* begin() const {
@@ -65,9 +77,28 @@ struct Index {
     return data_.End();
   }
 
+  void setQueue(::sycl::queue* qu) {
+    qu_ = qu;
+  }
+
  private:
+  static uint32_t GetValueFromUint8(const uint8_t* t, size_t i) {
+    return reinterpret_cast<const uint8_t*>(t)[i];
+  }
+  static uint32_t GetValueFromUint16(const uint8_t* t, size_t i) {
+    return reinterpret_cast<const uint16_t*>(t)[i];
+  }
+  static uint32_t GetValueFromUint32(const uint8_t* t, size_t i) {
+    return reinterpret_cast<const uint32_t*>(t)[i];
+  }
+
+  using Func = uint32_t (*)(const uint8_t*, size_t);
+
   USMVector<uint8_t, MemoryType::on_device> data_;
   BinTypeSize binTypeSize_ {BinTypeSize::kUint8BinsTypeSize};
+  Func func_;
+
+  ::sycl::queue* qu_;
 };
 
 /*!
@@ -85,19 +116,22 @@ struct GHistIndexMatrix {
   USMVector<uint8_t, MemoryType::on_device> sort_buff;
   /*! \brief The corresponding cuts */
   xgboost::common::HistogramCuts cut;
+  DMatrix* p_fmat;
   size_t max_num_bins;
   size_t nbins;
   size_t nfeatures;
   size_t row_stride;
 
   // Create a global histogram matrix based on a given DMatrix device wrapper
-  void Init(::sycl::queue* qu, Context const * ctx, DMatrix *dmat, int max_num_bins);
+  void Init(::sycl::queue* qu, Context const * ctx,
+            DMatrix *dmat, int max_num_bins);
 
   template <typename BinIdxType, bool isDense>
-  void SetIndexData(::sycl::queue* qu, Context const * ctx, BinIdxType* index_data,
-                    DMatrix *dmat);
+  void SetIndexData(::sycl::queue* qu, BinIdxType* index_data,
+                    DMatrix *dmat,
+                    size_t nbins, size_t row_stride);
 
-  void ResizeIndex(::sycl::queue* qu, size_t n_index);
+  void ResizeIndex(size_t n_index, bool isDense);
 
   inline void GetFeatureCounts(size_t* counts) const {
     auto nfeature = cut.cut_ptrs_.Size() - 1;

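Note: the restored Index selects a reader function once in SetBinTypeSize, so bin values stored as 8-, 16-, or 32-bit integers can all be read through one uniform call on a byte buffer. A minimal host-side sketch of the same dispatch idea (hypothetical BinReader type, not the actual xgboost class):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Reads packed bin indices of a runtime-chosen width through a function
// pointer picked once, instead of branching on the width at every access.
struct BinReader {
  using Func = uint32_t (*)(const uint8_t*, size_t);

  static uint32_t FromUint8(const uint8_t* t, size_t i)  { return t[i]; }
  static uint32_t FromUint16(const uint8_t* t, size_t i) {
    return reinterpret_cast<const uint16_t*>(t)[i];
  }
  static uint32_t FromUint32(const uint8_t* t, size_t i) {
    return reinterpret_cast<const uint32_t*>(t)[i];
  }

  void SetWidth(int bytes) {
    func_ = bytes == 1 ? &FromUint8 : bytes == 2 ? &FromUint16 : &FromUint32;
  }
  uint32_t operator[](size_t i) const { return func_(data_, i); }

  const uint8_t* data_ = nullptr;
  Func func_ = &FromUint8;
};

int main() {
  uint16_t bins[] = {7, 42, 65535};
  BinReader r;
  r.data_ = reinterpret_cast<const uint8_t*>(bins);
  r.SetWidth(2);
  std::printf("%u %u %u\n", r[0], r[1], r[2]);  // prints: 7 42 65535
}
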
plugin/sycl/predictor/predictor.cc — 4 additions, 3 deletions

@@ -291,7 +291,7 @@ class Predictor : public xgboost::Predictor {
         }
 
         if (num_group == 1) {
-          float& sum = out_predictions[row_idx];
+          float sum = 0.0;
           for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
             const Node* first_node = nodes + first_node_position[tree_idx - tree_begin];
             if constexpr (any_missing) {
@@ -300,6 +300,7 @@ class Predictor : public xgboost::Predictor {
               sum += GetLeafWeight(first_node, fval_buff_row_ptr);
             }
           }
+          out_predictions[row_idx] += sum;
         } else {
           for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
             const Node* first_node = nodes + first_node_position[tree_idx - tree_begin];
@@ -332,6 +333,7 @@ class Predictor : public xgboost::Predictor {
     int num_features = dmat->Info().num_col_;
 
     float* out_predictions = out_preds->DevicePointer();
+    ::sycl::event event;
     for (auto &batch : dmat->GetBatches<SparsePage>()) {
       batch.data.SetDevice(ctx_->Device());
       batch.offset.SetDevice(ctx_->Device());
@@ -341,7 +343,6 @@ class Predictor : public xgboost::Predictor {
       if (batch_size > 0) {
         const auto base_rowid = batch.base_rowid;
 
-        ::sycl::event event;
         if (needs_buffer_update) {
           fval_buff.ResizeNoCopy(qu_, num_features * batch_size);
           if constexpr (any_missing) {
@@ -353,9 +354,9 @@ class Predictor : public xgboost::Predictor {
                        row_ptr, batch_size, num_features,
                        num_group, tree_begin, tree_end);
         needs_buffer_update = (batch_size != out_preds->Size());
-        qu_->wait();
       }
     }
+    qu_->wait();
   }
 
   mutable USMVector<float, MemoryType::on_device> fval_buff;

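Note: the restored predictor accumulates the per-tree leaf weights in a local sum and commits it with a single += at the end, instead of read-modify-writing the output element through a reference on every tree; the final += also adds on top of whatever value the row already holds (for example, predictions carried over from earlier boosting rounds). A minimal sketch of that pattern (hypothetical names, plain C++ in place of the SYCL kernel):

#include <cstdio>
#include <vector>

// Stand-in for GetLeafWeight; the real code walks a tree for the row.
float LeafWeight(int tree, int row) { return 0.1f * tree + row; }

void PredictRow(std::vector<float>& out, int row, int tree_begin, int tree_end) {
  float sum = 0.0f;                 // register-friendly local accumulator
  for (int t = tree_begin; t < tree_end; ++t)
    sum += LeafWeight(t, row);
  out[row] += sum;                  // one global update, prior value preserved
}

int main() {
  std::vector<float> out(2, 0.5f);  // e.g. margin from a previous round
  PredictRow(out, 0, 0, 3);
  PredictRow(out, 1, 0, 3);
  std::printf("%.2f %.2f\n", out[0], out[1]);
}
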
tests/cpp/plugin/test_sycl_partition_builder.cc — 1 addition, 1 deletion

@@ -67,7 +67,7 @@ void TestPartitioning(float sparsity, int max_bins) {
 
   std::vector<uint8_t> ridx_left(num_rows, 0);
   std::vector<uint8_t> ridx_right(num_rows, 0);
-  for (auto &batch : p_fmat->GetBatches<SparsePage>()) {
+  for (auto &batch : gmat.p_fmat->GetBatches<SparsePage>()) {
     const auto& data_vec = batch.data.HostVector();
     const auto& offset_vec = batch.offset.HostVector();

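Note: the test now reaches the SparsePage batches through gmat.p_fmat, the back-pointer that the reverted gradient_index.h adds to GHistIndexMatrix. A minimal sketch of that design (hypothetical Matrix/HistIndex types, not the xgboost classes): the index keeps a non-owning pointer to the matrix it was built from, so consumers only need the index object:

#include <cstdio>
#include <vector>

struct Matrix {
  std::vector<std::vector<float>> batches;  // stand-in for SparsePage batches
};

struct HistIndex {
  const Matrix* p_fmat = nullptr;  // non-owning back-pointer, set in Init()
  void Init(const Matrix* m) { p_fmat = m; /* build the index here */ }
};

int main() {
  Matrix m{{{1.f, 2.f}, {3.f}}};
  HistIndex gmat;
  gmat.Init(&m);

  // Consumers iterate the source data via the index, as the test does
  // with gmat.p_fmat->GetBatches<SparsePage>().
  for (const auto& batch : gmat.p_fmat->batches)
    std::printf("batch size: %zu\n", batch.size());
}
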
tests/python-sycl/test_sycl_training_continuation.py — 2 additions, 2 deletions

@@ -9,8 +9,8 @@ class TestSYCLTrainingContinuation:
     def run_training_continuation(self, use_json):
         kRows = 64
         kCols = 32
-        X = rng.randn(kRows, kCols)
-        y = rng.randn(kRows)
+        X = np.random.randn(kRows, kCols)
+        y = np.random.randn(kRows)
         dtrain = xgb.DMatrix(X, y)
         params = {
             "device": "sycl",
