@@ -50,29 +50,28 @@ void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) {
50
50
51
51
template <typename BinIdxType, bool isDense>
52
52
void GHistIndexMatrix::SetIndexData (::sycl::queue* qu,
53
- Context const * ctx,
54
53
BinIdxType* index_data,
55
- DMatrix *dmat) {
54
+ DMatrix *dmat,
55
+ size_t nbins,
56
+ size_t row_stride) {
56
57
if (nbins == 0 ) return ;
57
58
const bst_float* cut_values = cut.cut_values_ .ConstDevicePointer ();
58
59
const uint32_t * cut_ptrs = cut.cut_ptrs_ .ConstDevicePointer ();
59
60
size_t * hit_count_ptr = hit_count.DevicePointer ();
60
61
61
62
BinIdxType* sort_data = reinterpret_cast <BinIdxType*>(sort_buff.Data ());
62
63
64
+ ::sycl::event event;
63
65
for (auto &batch : dmat->GetBatches <SparsePage>()) {
64
- batch.data .SetDevice (ctx->Device ());
65
- batch.offset .SetDevice (ctx->Device ());
66
-
67
- const xgboost::Entry *data_ptr = batch.data .ConstDevicePointer ();
68
- const bst_idx_t *offset_vec = batch.offset .ConstDevicePointer ();
69
- size_t batch_size = batch.Size ();
70
- if (batch_size > 0 ) {
71
- const auto base_rowid = batch.base_rowid ;
72
- size_t row_stride = this ->row_stride ;
73
- size_t nbins = this ->nbins ;
74
- qu->submit ([&](::sycl::handler& cgh) {
75
- cgh.parallel_for <>(::sycl::range<1 >(batch_size), [=](::sycl::item<1 > pid) {
66
+ for (auto &batch : dmat->GetBatches <SparsePage>()) {
67
+ const xgboost::Entry *data_ptr = batch.data .ConstDevicePointer ();
68
+ const bst_idx_t *offset_vec = batch.offset .ConstDevicePointer ();
69
+ size_t batch_size = batch.Size ();
70
+ if (batch_size > 0 ) {
71
+ const auto base_rowid = batch.base_rowid ;
72
+ event = qu->submit ([&](::sycl::handler& cgh) {
73
+ cgh.depends_on (event);
74
+ cgh.parallel_for <>(::sycl::range<1 >(batch_size), [=](::sycl::item<1 > pid) {
76
75
const size_t i = pid.get_id (0 );
77
76
const size_t ibegin = offset_vec[i];
78
77
const size_t iend = offset_vec[i + 1 ];
@@ -93,22 +92,23 @@ void GHistIndexMatrix::SetIndexData(::sycl::queue* qu,
93
92
}
94
93
});
95
94
});
96
- qu-> wait ();
95
+ }
97
96
}
98
97
}
98
+ qu->wait ();
99
99
}
100
100
101
- void GHistIndexMatrix::ResizeIndex (::sycl::queue* qu, size_t n_index ) {
102
- if ((max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint8_t >::max ())) && isDense_ ) {
101
+ void GHistIndexMatrix::ResizeIndex (size_t n_index, bool isDense ) {
102
+ if ((max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint8_t >::max ())) && isDense ) {
103
103
index.SetBinTypeSize (BinTypeSize::kUint8BinsTypeSize );
104
- index.Resize (qu, (sizeof (uint8_t )) * n_index);
104
+ index.Resize ((sizeof (uint8_t )) * n_index);
105
105
} else if ((max_num_bins - 1 > static_cast <int >(std::numeric_limits<uint8_t >::max ()) &&
106
- max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint16_t >::max ())) && isDense_ ) {
106
+ max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint16_t >::max ())) && isDense ) {
107
107
index.SetBinTypeSize (BinTypeSize::kUint16BinsTypeSize );
108
- index.Resize (qu, (sizeof (uint16_t )) * n_index);
108
+ index.Resize ((sizeof (uint16_t )) * n_index);
109
109
} else {
110
110
index.SetBinTypeSize (BinTypeSize::kUint32BinsTypeSize );
111
- index.Resize (qu, (sizeof (uint32_t )) * n_index);
111
+ index.Resize ((sizeof (uint32_t )) * n_index);
112
112
}
113
113
}
114
114
@@ -122,50 +122,52 @@ void GHistIndexMatrix::Init(::sycl::queue* qu,
122
122
cut.SetDevice (ctx->Device ());
123
123
124
124
max_num_bins = max_bins;
125
- nbins = cut.Ptrs ().back ();
125
+ const uint32_t nbins = cut.Ptrs ().back ();
126
+ this ->nbins = nbins;
126
127
127
128
hit_count.SetDevice (ctx->Device ());
128
129
hit_count.Resize (nbins, 0 );
129
130
131
+ this ->p_fmat = dmat;
130
132
const bool isDense = dmat->IsDense ();
131
133
this ->isDense_ = isDense;
132
134
135
+ index.setQueue (qu);
136
+
133
137
row_stride = 0 ;
134
138
size_t n_rows = 0 ;
135
- if (!isDense ) {
136
- for ( const auto & batch : dmat-> GetBatches <SparsePage>()) {
137
- const auto & row_offset = batch.offset . ConstHostVector ( );
138
- n_rows += batch.Size ( );
139
- for ( auto i = 1ull ; i < row_offset. size (); i++) {
140
- row_stride = std::max (row_stride, static_cast < size_t >(row_offset[i] - row_offset[i - 1 ]));
141
- }
139
+ for ( const auto & batch : dmat-> GetBatches <SparsePage>() ) {
140
+ const auto & row_offset = batch. offset . ConstHostVector ();
141
+ batch.data . SetDevice (ctx-> Device () );
142
+ batch.offset . SetDevice (ctx-> Device () );
143
+ n_rows += batch. Size ();
144
+ for ( auto i = 1ull ; i < row_offset. size (); i++) {
145
+ row_stride = std::max (row_stride, static_cast < size_t >(row_offset[i] - row_offset[i - 1 ]));
142
146
}
143
- } else {
144
- row_stride = nfeatures;
145
- n_rows = dmat->Info ().num_row_ ;
146
147
}
147
148
148
149
const size_t n_offsets = cut.cut_ptrs_ .Size () - 1 ;
149
150
const size_t n_index = n_rows * row_stride;
150
- ResizeIndex (qu, n_index );
151
+ ResizeIndex (n_index, isDense );
151
152
152
153
CHECK_GT (cut.cut_values_ .Size (), 0U );
153
154
154
155
if (isDense) {
155
156
BinTypeSize curent_bin_size = index.GetBinTypeSize ();
156
157
if (curent_bin_size == BinTypeSize::kUint8BinsTypeSize ) {
157
- SetIndexData<uint8_t , true >(qu, ctx, index.data <uint8_t >(), dmat);
158
+ SetIndexData<uint8_t , true >(qu, index.data <uint8_t >(), dmat, nbins, row_stride);
159
+
158
160
} else if (curent_bin_size == BinTypeSize::kUint16BinsTypeSize ) {
159
- SetIndexData<uint16_t , true >(qu, ctx, index.data <uint16_t >(), dmat);
161
+ SetIndexData<uint16_t , true >(qu, index.data <uint16_t >(), dmat, nbins, row_stride );
160
162
} else {
161
163
CHECK_EQ (curent_bin_size, BinTypeSize::kUint32BinsTypeSize );
162
- SetIndexData<uint32_t , true >(qu, ctx, index.data <uint32_t >(), dmat);
164
+ SetIndexData<uint32_t , true >(qu, index.data <uint32_t >(), dmat, nbins, row_stride );
163
165
}
164
166
/* For a sparse DMatrix we have to store the feature index for each bin
165
167
in index field to choose the right offset. So offset is nullptr and index is not reduced */
166
168
} else {
167
169
sort_buff.Resize (qu, n_rows * row_stride * sizeof (uint32_t ));
168
- SetIndexData<uint32_t , false >(qu, ctx, index.data <uint32_t >(), dmat);
170
+ SetIndexData<uint32_t , false >(qu, index.data <uint32_t >(), dmat, nbins, row_stride );
169
171
}
170
172
}
171
173
0 commit comments