Skip to content

Commit cd22e8f

Browse files
author
Dmitry Razdoburdin
committed
fix training continuation for iGPUs
1 parent ed470ad commit cd22e8f

File tree

4 files changed

+50
-87
lines changed

4 files changed

+50
-87
lines changed

plugin/sycl/data/gradient_index.cc

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -50,28 +50,29 @@ void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) {
5050

5151
template <typename BinIdxType, bool isDense>
5252
void GHistIndexMatrix::SetIndexData(::sycl::queue* qu,
53+
Context const * ctx,
5354
BinIdxType* index_data,
54-
DMatrix *dmat,
55-
size_t nbins,
56-
size_t row_stride) {
55+
DMatrix *dmat) {
5756
if (nbins == 0) return;
5857
const bst_float* cut_values = cut.cut_values_.ConstDevicePointer();
5958
const uint32_t* cut_ptrs = cut.cut_ptrs_.ConstDevicePointer();
6059
size_t* hit_count_ptr = hit_count.DevicePointer();
6160

6261
BinIdxType* sort_data = reinterpret_cast<BinIdxType*>(sort_buff.Data());
6362

64-
::sycl::event event;
6563
for (auto &batch : dmat->GetBatches<SparsePage>()) {
66-
for (auto &batch : dmat->GetBatches<SparsePage>()) {
67-
const xgboost::Entry *data_ptr = batch.data.ConstDevicePointer();
68-
const bst_idx_t *offset_vec = batch.offset.ConstDevicePointer();
69-
size_t batch_size = batch.Size();
70-
if (batch_size > 0) {
71-
const auto base_rowid = batch.base_rowid;
72-
event = qu->submit([&](::sycl::handler& cgh) {
73-
cgh.depends_on(event);
74-
cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::item<1> pid) {
64+
batch.data.SetDevice(ctx->Device());
65+
batch.offset.SetDevice(ctx->Device());
66+
67+
const xgboost::Entry *data_ptr = batch.data.ConstDevicePointer();
68+
const bst_idx_t *offset_vec = batch.offset.ConstDevicePointer();
69+
size_t batch_size = batch.Size();
70+
if (batch_size > 0) {
71+
const auto base_rowid = batch.base_rowid;
72+
size_t row_stride = this->row_stride;
73+
size_t nbins = this->nbins;
74+
qu->submit([&](::sycl::handler& cgh) {
75+
cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::item<1> pid) {
7576
const size_t i = pid.get_id(0);
7677
const size_t ibegin = offset_vec[i];
7778
const size_t iend = offset_vec[i + 1];
@@ -92,23 +93,22 @@ void GHistIndexMatrix::SetIndexData(::sycl::queue* qu,
9293
}
9394
});
9495
});
95-
}
96+
qu->wait();
9697
}
9798
}
98-
qu->wait();
9999
}
100100

101-
void GHistIndexMatrix::ResizeIndex(size_t n_index, bool isDense) {
102-
if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
101+
void GHistIndexMatrix::ResizeIndex(::sycl::queue* qu, size_t n_index) {
102+
if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense_) {
103103
index.SetBinTypeSize(BinTypeSize::kUint8BinsTypeSize);
104-
index.Resize((sizeof(uint8_t)) * n_index);
104+
index.Resize(qu, (sizeof(uint8_t)) * n_index);
105105
} else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
106-
max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
106+
max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense_) {
107107
index.SetBinTypeSize(BinTypeSize::kUint16BinsTypeSize);
108-
index.Resize((sizeof(uint16_t)) * n_index);
108+
index.Resize(qu, (sizeof(uint16_t)) * n_index);
109109
} else {
110110
index.SetBinTypeSize(BinTypeSize::kUint32BinsTypeSize);
111-
index.Resize((sizeof(uint32_t)) * n_index);
111+
index.Resize(qu, (sizeof(uint32_t)) * n_index);
112112
}
113113
}
114114

@@ -122,52 +122,50 @@ void GHistIndexMatrix::Init(::sycl::queue* qu,
122122
cut.SetDevice(ctx->Device());
123123

124124
max_num_bins = max_bins;
125-
const uint32_t nbins = cut.Ptrs().back();
126-
this->nbins = nbins;
125+
nbins = cut.Ptrs().back();
127126

128127
hit_count.SetDevice(ctx->Device());
129128
hit_count.Resize(nbins, 0);
130129

131-
this->p_fmat = dmat;
132130
const bool isDense = dmat->IsDense();
133131
this->isDense_ = isDense;
134132

135-
index.setQueue(qu);
136-
137133
row_stride = 0;
138134
size_t n_rows = 0;
139-
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
140-
const auto& row_offset = batch.offset.ConstHostVector();
141-
batch.data.SetDevice(ctx->Device());
142-
batch.offset.SetDevice(ctx->Device());
143-
n_rows += batch.Size();
144-
for (auto i = 1ull; i < row_offset.size(); i++) {
145-
row_stride = std::max(row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
135+
if (!isDense) {
136+
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
137+
const auto& row_offset = batch.offset.ConstHostVector();
138+
n_rows += batch.Size();
139+
for (auto i = 1ull; i < row_offset.size(); i++) {
140+
row_stride = std::max(row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
141+
}
146142
}
143+
} else {
144+
row_stride = nfeatures;
145+
n_rows = dmat->Info().num_row_;
147146
}
148147

149148
const size_t n_offsets = cut.cut_ptrs_.Size() - 1;
150149
const size_t n_index = n_rows * row_stride;
151-
ResizeIndex(n_index, isDense);
150+
ResizeIndex(qu, n_index);
152151

153152
CHECK_GT(cut.cut_values_.Size(), 0U);
154153

155154
if (isDense) {
156155
BinTypeSize curent_bin_size = index.GetBinTypeSize();
157156
if (curent_bin_size == BinTypeSize::kUint8BinsTypeSize) {
158-
SetIndexData<uint8_t, true>(qu, index.data<uint8_t>(), dmat, nbins, row_stride);
159-
157+
SetIndexData<uint8_t, true>(qu, ctx, index.data<uint8_t>(), dmat);
160158
} else if (curent_bin_size == BinTypeSize::kUint16BinsTypeSize) {
161-
SetIndexData<uint16_t, true>(qu, index.data<uint16_t>(), dmat, nbins, row_stride);
159+
SetIndexData<uint16_t, true>(qu, ctx, index.data<uint16_t>(), dmat);
162160
} else {
163161
CHECK_EQ(curent_bin_size, BinTypeSize::kUint32BinsTypeSize);
164-
SetIndexData<uint32_t, true>(qu, index.data<uint32_t>(), dmat, nbins, row_stride);
162+
SetIndexData<uint32_t, true>(qu, ctx, index.data<uint32_t>(), dmat);
165163
}
166164
/* For sparse DMatrix we have to store index of feature for each bin
167165
in index field to chose right offset. So offset is nullptr and index is not reduced */
168166
} else {
169167
sort_buff.Resize(qu, n_rows * row_stride * sizeof(uint32_t));
170-
SetIndexData<uint32_t, false>(qu, index.data<uint32_t>(), dmat, nbins, row_stride);
168+
SetIndexData<uint32_t, false>(qu, ctx, index.data<uint32_t>(), dmat);
171169
}
172170
}
173171

plugin/sycl/data/gradient_index.h

Lines changed: 9 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,9 @@ struct Index {
3131
Index& operator=(Index&& i) = delete;
3232
void SetBinTypeSize(BinTypeSize binTypeSize) {
3333
binTypeSize_ = binTypeSize;
34-
switch (binTypeSize) {
35-
case BinTypeSize::kUint8BinsTypeSize:
36-
func_ = &GetValueFromUint8;
37-
break;
38-
case BinTypeSize::kUint16BinsTypeSize:
39-
func_ = &GetValueFromUint16;
40-
break;
41-
case BinTypeSize::kUint32BinsTypeSize:
42-
func_ = &GetValueFromUint32;
43-
break;
44-
default:
45-
CHECK(binTypeSize == BinTypeSize::kUint8BinsTypeSize ||
46-
binTypeSize == BinTypeSize::kUint16BinsTypeSize ||
47-
binTypeSize == BinTypeSize::kUint32BinsTypeSize);
48-
}
34+
CHECK(binTypeSize == BinTypeSize::kUint8BinsTypeSize ||
35+
binTypeSize == BinTypeSize::kUint16BinsTypeSize ||
36+
binTypeSize == BinTypeSize::kUint32BinsTypeSize);
4937
}
5038
BinTypeSize GetBinTypeSize() const {
5139
return binTypeSize_;
@@ -65,8 +53,8 @@ struct Index {
6553
return data_.Size() / (binTypeSize_);
6654
}
6755

68-
void Resize(const size_t nBytesData) {
69-
data_.Resize(qu_, nBytesData);
56+
void Resize(::sycl::queue* qu, const size_t nBytesData) {
57+
data_.Resize(qu, nBytesData);
7058
}
7159

7260
uint8_t* begin() const {
@@ -77,28 +65,9 @@ struct Index {
7765
return data_.End();
7866
}
7967

80-
void setQueue(::sycl::queue* qu) {
81-
qu_ = qu;
82-
}
83-
8468
private:
85-
static uint32_t GetValueFromUint8(const uint8_t* t, size_t i) {
86-
return reinterpret_cast<const uint8_t*>(t)[i];
87-
}
88-
static uint32_t GetValueFromUint16(const uint8_t* t, size_t i) {
89-
return reinterpret_cast<const uint16_t*>(t)[i];
90-
}
91-
static uint32_t GetValueFromUint32(const uint8_t* t, size_t i) {
92-
return reinterpret_cast<const uint32_t*>(t)[i];
93-
}
94-
95-
using Func = uint32_t (*)(const uint8_t*, size_t);
96-
9769
USMVector<uint8_t, MemoryType::on_device> data_;
9870
BinTypeSize binTypeSize_ {BinTypeSize::kUint8BinsTypeSize};
99-
Func func_;
100-
101-
::sycl::queue* qu_;
10271
};
10372

10473
/*!
@@ -116,22 +85,19 @@ struct GHistIndexMatrix {
11685
USMVector<uint8_t, MemoryType::on_device> sort_buff;
11786
/*! \brief The corresponding cuts */
11887
xgboost::common::HistogramCuts cut;
119-
DMatrix* p_fmat;
12088
size_t max_num_bins;
12189
size_t nbins;
12290
size_t nfeatures;
12391
size_t row_stride;
12492

12593
// Create a global histogram matrix based on a given DMatrix device wrapper
126-
void Init(::sycl::queue* qu, Context const * ctx,
127-
DMatrix *dmat, int max_num_bins);
94+
void Init(::sycl::queue* qu, Context const * ctx, DMatrix *dmat, int max_num_bins);
12895

12996
template <typename BinIdxType, bool isDense>
130-
void SetIndexData(::sycl::queue* qu, BinIdxType* index_data,
131-
DMatrix *dmat,
132-
size_t nbins, size_t row_stride);
97+
void SetIndexData(::sycl::queue* qu, Context const * ctx, BinIdxType* index_data,
98+
DMatrix *dmat);
13399

134-
void ResizeIndex(size_t n_index, bool isDense);
100+
void ResizeIndex(::sycl::queue* qu, size_t n_index);
135101

136102
inline void GetFeatureCounts(size_t* counts) const {
137103
auto nfeature = cut.cut_ptrs_.Size() - 1;

plugin/sycl/predictor/predictor.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ class Predictor : public xgboost::Predictor {
291291
}
292292

293293
if (num_group == 1) {
294-
float sum = 0.0;
294+
float& sum = out_predictions[row_idx];
295295
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
296296
const Node* first_node = nodes + first_node_position[tree_idx - tree_begin];
297297
if constexpr (any_missing) {
@@ -300,7 +300,6 @@ class Predictor : public xgboost::Predictor {
300300
sum += GetLeafWeight(first_node, fval_buff_row_ptr);
301301
}
302302
}
303-
out_predictions[row_idx] += sum;
304303
} else {
305304
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
306305
const Node* first_node = nodes + first_node_position[tree_idx - tree_begin];
@@ -333,7 +332,6 @@ class Predictor : public xgboost::Predictor {
333332
int num_features = dmat->Info().num_col_;
334333

335334
float* out_predictions = out_preds->DevicePointer();
336-
::sycl::event event;
337335
for (auto &batch : dmat->GetBatches<SparsePage>()) {
338336
batch.data.SetDevice(ctx_->Device());
339337
batch.offset.SetDevice(ctx_->Device());
@@ -343,6 +341,7 @@ class Predictor : public xgboost::Predictor {
343341
if (batch_size > 0) {
344342
const auto base_rowid = batch.base_rowid;
345343

344+
::sycl::event event;
346345
if (needs_buffer_update) {
347346
fval_buff.ResizeNoCopy(qu_, num_features * batch_size);
348347
if constexpr (any_missing) {
@@ -354,9 +353,9 @@ class Predictor : public xgboost::Predictor {
354353
row_ptr, batch_size, num_features,
355354
num_group, tree_begin, tree_end);
356355
needs_buffer_update = (batch_size != out_preds->Size());
356+
qu_->wait();
357357
}
358358
}
359-
qu_->wait();
360359
}
361360

362361
mutable USMVector<float, MemoryType::on_device> fval_buff;

tests/cpp/plugin/test_sycl_partition_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ void TestPartitioning(float sparsity, int max_bins) {
6767

6868
std::vector<uint8_t> ridx_left(num_rows, 0);
6969
std::vector<uint8_t> ridx_right(num_rows, 0);
70-
for (auto &batch : gmat.p_fmat->GetBatches<SparsePage>()) {
70+
for (auto &batch : p_fmat->GetBatches<SparsePage>()) {
7171
const auto& data_vec = batch.data.HostVector();
7272
const auto& offset_vec = batch.offset.HostVector();
7373

0 commit comments

Comments
 (0)