@@ -50,28 +50,29 @@ void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) {
 
 template <typename BinIdxType, bool isDense>
 void GHistIndexMatrix::SetIndexData(::sycl::queue* qu,
+                                    Context const * ctx,
                                     BinIdxType* index_data,
-                                    DMatrix *dmat,
-                                    size_t nbins,
-                                    size_t row_stride) {
+                                    DMatrix *dmat) {
   if (nbins == 0) return;
   const bst_float* cut_values = cut.cut_values_.ConstDevicePointer();
   const uint32_t* cut_ptrs = cut.cut_ptrs_.ConstDevicePointer();
   size_t* hit_count_ptr = hit_count.DevicePointer();
 
   BinIdxType* sort_data = reinterpret_cast<BinIdxType*>(sort_buff.Data());
 
-  ::sycl::event event;
   for (auto &batch : dmat->GetBatches<SparsePage>()) {
-    for (auto &batch : dmat->GetBatches<SparsePage>()) {
-      const xgboost::Entry *data_ptr = batch.data.ConstDevicePointer();
-      const bst_idx_t *offset_vec = batch.offset.ConstDevicePointer();
-      size_t batch_size = batch.Size();
-      if (batch_size > 0) {
-        const auto base_rowid = batch.base_rowid;
-        event = qu->submit([&](::sycl::handler& cgh) {
-          cgh.depends_on(event);
-          cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::item<1> pid) {
+    batch.data.SetDevice(ctx->Device());
+    batch.offset.SetDevice(ctx->Device());
+
+    const xgboost::Entry *data_ptr = batch.data.ConstDevicePointer();
+    const bst_idx_t *offset_vec = batch.offset.ConstDevicePointer();
+    size_t batch_size = batch.Size();
+    if (batch_size > 0) {
+      const auto base_rowid = batch.base_rowid;
+      size_t row_stride = this->row_stride;
+      size_t nbins = this->nbins;
+      qu->submit([&](::sycl::handler& cgh) {
+        cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::item<1> pid) {
           const size_t i = pid.get_id(0);
           const size_t ibegin = offset_vec[i];
           const size_t iend = offset_vec[i + 1];
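The hunk above replaces the event-chained submissions (`cgh.depends_on(event)`) with plain per-batch submissions that copy `row_stride` and `nbins` into locals, so the device lambda captures values rather than `this`; the matching `qu->wait()` moves inside the batch loop in the next hunk. A minimal standalone sketch of that per-batch submit-and-wait pattern, assuming a standard SYCL 2020 toolchain and default device selection; the container and variable names are illustrative, not from xgboost:

```cpp
// Minimal sketch (not the xgboost code): one kernel submission per batch,
// with a blocking wait before the next batch is touched.
#include <sycl/sycl.hpp>
#include <vector>

int main() {
  sycl::queue qu;
  std::vector<std::vector<float>> batches = {{1.f, 2.f, 3.f}, {4.f, 5.f}};

  for (auto& batch : batches) {
    const size_t n = batch.size();
    float* data = sycl::malloc_device<float>(n, qu);
    qu.memcpy(data, batch.data(), n * sizeof(float)).wait();

    // Submit one kernel for this batch; the captured pointer is USM device
    // memory, so the lambda captures plain values, not the host object.
    qu.submit([&](sycl::handler& cgh) {
      cgh.parallel_for(sycl::range<1>(n), [=](sycl::item<1> pid) {
        const size_t i = pid.get_id(0);
        data[i] *= 2.f;
      });
    });
    // Mirror of the qu->wait() inside the batch loop.
    qu.wait();

    qu.memcpy(batch.data(), data, n * sizeof(float)).wait();
    sycl::free(data, qu);
  }
  return 0;
}
```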
@@ -92,23 +93,22 @@ void GHistIndexMatrix::SetIndexData(::sycl::queue* qu,
           }
         });
       });
-    }
+      qu->wait();
     }
   }
-  qu->wait();
 }
 
-void GHistIndexMatrix::ResizeIndex(size_t n_index, bool isDense) {
-  if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
+void GHistIndexMatrix::ResizeIndex(::sycl::queue* qu, size_t n_index) {
+  if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense_) {
     index.SetBinTypeSize(BinTypeSize::kUint8BinsTypeSize);
-    index.Resize((sizeof(uint8_t)) * n_index);
+    index.Resize(qu, (sizeof(uint8_t)) * n_index);
   } else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
-              max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
+              max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense_) {
     index.SetBinTypeSize(BinTypeSize::kUint16BinsTypeSize);
-    index.Resize((sizeof(uint16_t)) * n_index);
+    index.Resize(qu, (sizeof(uint16_t)) * n_index);
   } else {
     index.SetBinTypeSize(BinTypeSize::kUint32BinsTypeSize);
-    index.Resize((sizeof(uint32_t)) * n_index);
+    index.Resize(qu, (sizeof(uint32_t)) * n_index);
   }
 }
 
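The ResizeIndex changes above pass the queue through to `index.Resize` and read the density flag from the member `isDense_` instead of a parameter; the width-selection rule itself is unchanged: narrow bin storage is only used for dense data whose bin ids fit the smaller type, while sparse data keeps full 32-bit indices. A minimal sketch of that selection rule, with hypothetical names (`BinWidth`, `SelectBinWidth`) standing in for the real `BinTypeSize` machinery:

```cpp
// Sketch of the bin-width selection rule used by ResizeIndex (illustrative
// names; the real code sets BinTypeSize and resizes the index buffer).
#include <cstdint>
#include <limits>

enum class BinWidth { kUint8, kUint16, kUint32 };

BinWidth SelectBinWidth(int max_num_bins, bool is_dense) {
  // Narrow types are only safe for dense matrices, where the stored value
  // stays within the per-feature bin range.
  if (is_dense && max_num_bins - 1 <= static_cast<int>(std::numeric_limits<std::uint8_t>::max())) {
    return BinWidth::kUint8;
  }
  if (is_dense && max_num_bins - 1 <= static_cast<int>(std::numeric_limits<std::uint16_t>::max())) {
    return BinWidth::kUint16;
  }
  return BinWidth::kUint32;  // sparse matrices and large bin counts
}
```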
@@ -122,52 +122,50 @@ void GHistIndexMatrix::Init(::sycl::queue* qu,
   cut.SetDevice(ctx->Device());
 
   max_num_bins = max_bins;
-  const uint32_t nbins = cut.Ptrs().back();
-  this->nbins = nbins;
+  nbins = cut.Ptrs().back();
 
   hit_count.SetDevice(ctx->Device());
   hit_count.Resize(nbins, 0);
 
-  this->p_fmat = dmat;
   const bool isDense = dmat->IsDense();
   this->isDense_ = isDense;
 
-  index.setQueue(qu);
-
   row_stride = 0;
   size_t n_rows = 0;
-  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
-    const auto& row_offset = batch.offset.ConstHostVector();
-    batch.data.SetDevice(ctx->Device());
-    batch.offset.SetDevice(ctx->Device());
-    n_rows += batch.Size();
-    for (auto i = 1ull; i < row_offset.size(); i++) {
-      row_stride = std::max(row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
+  if (!isDense) {
+    for (const auto& batch : dmat->GetBatches<SparsePage>()) {
+      const auto& row_offset = batch.offset.ConstHostVector();
+      n_rows += batch.Size();
+      for (auto i = 1ull; i < row_offset.size(); i++) {
+        row_stride = std::max(row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
+      }
     }
+  } else {
+    row_stride = nfeatures;
+    n_rows = dmat->Info().num_row_;
   }
 
   const size_t n_offsets = cut.cut_ptrs_.Size() - 1;
   const size_t n_index = n_rows * row_stride;
-  ResizeIndex(n_index, isDense);
+  ResizeIndex(qu, n_index);
 
   CHECK_GT(cut.cut_values_.Size(), 0U);
 
   if (isDense) {
     BinTypeSize curent_bin_size = index.GetBinTypeSize();
     if (curent_bin_size == BinTypeSize::kUint8BinsTypeSize) {
-      SetIndexData<uint8_t, true>(qu, index.data<uint8_t>(), dmat, nbins, row_stride);
-
+      SetIndexData<uint8_t, true>(qu, ctx, index.data<uint8_t>(), dmat);
     } else if (curent_bin_size == BinTypeSize::kUint16BinsTypeSize) {
-      SetIndexData<uint16_t, true>(qu, index.data<uint16_t>(), dmat, nbins, row_stride);
+      SetIndexData<uint16_t, true>(qu, ctx, index.data<uint16_t>(), dmat);
     } else {
       CHECK_EQ(curent_bin_size, BinTypeSize::kUint32BinsTypeSize);
-      SetIndexData<uint32_t, true>(qu, index.data<uint32_t>(), dmat, nbins, row_stride);
+      SetIndexData<uint32_t, true>(qu, ctx, index.data<uint32_t>(), dmat);
    }
    /* For sparse DMatrix we have to store index of feature for each bin
       in index field to chose right offset. So offset is nullptr and index is not reduced */
   } else {
     sort_buff.Resize(qu, n_rows * row_stride * sizeof(uint32_t));
-    SetIndexData<uint32_t, false>(qu, index.data<uint32_t>(), dmat, nbins, row_stride);
+    SetIndexData<uint32_t, false>(qu, ctx, index.data<uint32_t>(), dmat);
   }
 }
 
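In the Init hunk above, the row geometry is computed per layout: sparse input still scans the CSR row offsets for the longest row, while dense input takes `row_stride` directly from `nfeatures` and `n_rows` from `dmat->Info().num_row_`, skipping the scan. A minimal single-batch sketch of that computation, assuming `offsets` stands in for the SparsePage CSR offset array (all names here are illustrative):

```cpp
// Sketch of the dense/sparse row_stride and n_rows computation in Init.
#include <algorithm>
#include <cstddef>
#include <vector>

struct Shape { std::size_t n_rows; std::size_t row_stride; };

Shape ComputeShape(bool is_dense, std::size_t n_features, std::size_t num_row,
                   const std::vector<std::size_t>& offsets) {
  if (is_dense) {
    // Every row has exactly n_features entries, so no offset scan is needed.
    return {num_row, n_features};
  }
  // Sparse: the stride is the length of the longest row in the CSR layout.
  std::size_t row_stride = 0;
  for (std::size_t i = 1; i < offsets.size(); ++i) {
    row_stride = std::max(row_stride, offsets[i] - offsets[i - 1]);
  }
  return {offsets.empty() ? 0 : offsets.size() - 1, row_stride};
}
```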