Commit 74efac6

fp32 fix for objectives calculations (#70)
* fix for sycl iGPU
* linting

Co-authored-by: Dmitry Razdoburdin <>
1 parent ed470ad commit 74efac6
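
In short, this patch threads a compile-time fp64 flag through common::Transform: the device-specific launchers now pass std::true_type or std::false_type as a second functor argument, and objective functors branch on it with if constexpr, so devices without fp64 support (such as many SYCL iGPUs) never accumulate in double. A minimal sketch of that dispatch pattern, as read from the diff below; the functor and names here are illustrative, not part of the patch:

```cpp
#include <iostream>
#include <type_traits>

// Hypothetical functor following the new Transform signature: the extra
// argument is std::true_type on fp64-capable devices and std::false_type
// otherwise, so the accumulator type is selected at compile time.
struct SumFunctor {
  template <typename HasFp64>
  float operator()(const float* data, int n, HasFp64 /*has_fp64_support*/) const {
    if constexpr (HasFp64::value) {
      double sum = 0.0;   // wide accumulator on devices with fp64 support
      for (int i = 0; i < n; ++i) sum += data[i];
      return static_cast<float>(sum);
    } else {
      float sum = 0.0f;   // fp32-only fallback, e.g. iGPUs without fp64
      for (int i = 0; i < n; ++i) sum += data[i];
      return sum;
    }
  }
};

int main() {
  const float data[] = {1.0f, 2.5f, 3.25f};
  SumFunctor f;
  // The real launcher makes this choice from a runtime device query;
  // here both paths are called directly for illustration.
  std::cout << f(data, 3, std::true_type{}) << " "
            << f(data, 3, std::false_type{}) << "\n";
}
```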

9 files changed: +55 −30 lines


plugin/sycl/common/transform.h

Lines changed: 18 additions & 7 deletions
@@ -20,13 +20,24 @@ void LaunchSyclKernel(DeviceOrd device, Functor&& _func, xgboost::common::Range
   auto* qu = device_manager.GetQueue(device);
 
   size_t size = *(_range.end());
-  qu->submit([&](::sycl::handler& cgh) {
-    cgh.parallel_for<>(::sycl::range<1>(size),
-                       [=](::sycl::id<1> pid) {
-      const size_t idx = pid[0];
-      const_cast<Functor&&>(_func)(idx, _spans...);
-    });
-  }).wait();
+  const bool has_fp64_support = qu->get_device().has(::sycl::aspect::fp64);
+  if (has_fp64_support) {
+    qu->submit([&](::sycl::handler& cgh) {
+      cgh.parallel_for<>(::sycl::range<1>(size),
+                         [=](::sycl::id<1> pid) {
+        const size_t idx = pid[0];
+        const_cast<Functor&&>(_func)(idx, std::true_type(), _spans...);
+      });
+    }).wait();
+  } else {
+    qu->submit([&](::sycl::handler& cgh) {
+      cgh.parallel_for<>(::sycl::range<1>(size),
+                         [=](::sycl::id<1> pid) {
+        const size_t idx = pid[0];
+        const_cast<Functor&&>(_func)(idx, std::false_type(), _spans...);
+      });
+    }).wait();
+  }
 }
 
 }  // namespace common
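
The fp64 branch above is decided once per launch by querying the device aspect. A standalone sketch of that query, assuming a SYCL 2020 toolchain and the default device selector; this snippet is not part of the commit:

```cpp
#include <sycl/sycl.hpp>
#include <iostream>

int main() {
  ::sycl::queue qu{::sycl::default_selector_v};
  // Same check the launcher performs: does the selected device provide
  // double-precision arithmetic? Many integrated GPUs do not.
  const bool has_fp64_support = qu.get_device().has(::sycl::aspect::fp64);
  std::cout << qu.get_device().get_info<::sycl::info::device::name>() << ": "
            << (has_fp64_support ? "fp64 supported" : "fp32 only") << "\n";
}
```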

src/common/transform.h

Lines changed: 3 additions & 2 deletions
@@ -37,7 +37,7 @@ template <typename Functor, typename... SpanType>
 __global__ void LaunchCUDAKernel(Functor _func, Range _range,
                                  SpanType... _spans) {
   for (auto i : dh::GridStrideRange(*_range.begin(), *_range.end())) {
-    _func(i, _spans...);
+    _func(i, std::true_type(), _spans...);
   }
 }
 #endif  // defined(__CUDACC__)
@@ -184,7 +184,8 @@ class Transform {
   void LaunchCPU(Functor func, HDV *...vectors) const {
     omp_ulong end = static_cast<omp_ulong>(*(range_.end()));
     SyncHost(vectors...);
-    ParallelFor(end, n_threads_, [&](omp_ulong idx) { func(idx, UnpackHDV(vectors)...); });
+    ParallelFor(end, n_threads_, [&](omp_ulong idx) { func(idx, std::true_type(),
+                                                           UnpackHDV(vectors)...); });
   }
 
  private:

src/objective/aft_obj.cu

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ class AFTObj : public ObjFunction {
                      linalg::Matrix<GradientPair>* out_gpair, size_t ndata, DeviceOrd device,
                      bool is_null_weight, float aft_loss_distribution_scale) {
     common::Transform<>::Init(
-        [=] XGBOOST_DEVICE(size_t _idx,
+        [=] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support,
                            common::Span<GradientPair> _out_gpair,
                            common::Span<const bst_float> _preds,
                            common::Span<const bst_float> _labels_lower_bound,
@@ -104,7 +104,7 @@ class AFTObj : public ObjFunction {
   void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     // Trees give us a prediction in log scale, so exponentiate
     common::Transform<>::Init(
-        [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
+        [] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support, common::Span<bst_float> _preds) {
          _preds[_idx] = exp(_preds[_idx]);
        },
        common::Range{0, static_cast<int64_t>(io_preds->Size())}, this->ctx_->Threads(),

src/objective/hinge.cu

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ class HingeObj : public FitIntercept {
 
   void PredTransform(HostDeviceVector<float> *io_preds) const override {
     common::Transform<>::Init(
-        [] XGBOOST_DEVICE(std::size_t _idx, common::Span<float> _preds) {
+        [] XGBOOST_DEVICE(std::size_t _idx, auto has_fp64_support, common::Span<float> _preds) {
          _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
        },
        common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, this->ctx_->Threads(),

src/objective/multiclass_obj.cu

Lines changed: 17 additions & 8 deletions
@@ -75,7 +75,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
     }
 
     common::Transform<>::Init(
-        [=] XGBOOST_DEVICE(size_t idx,
+        [=] XGBOOST_DEVICE(size_t idx, auto has_fp64_support,
                            common::Span<GradientPair> gpair,
                            common::Span<bst_float const> labels,
                            common::Span<bst_float const> preds,
@@ -86,8 +86,16 @@ class SoftmaxMultiClassObj : public ObjFunction {
           // Part of Softmax function
           bst_float wmax = std::numeric_limits<bst_float>::min();
           for (auto const i : point) { wmax = fmaxf(i, wmax); }
-          double wsum = 0.0f;
-          for (auto const i : point) { wsum += expf(i - wmax); }
+
+          float wsum = 0.0f;
+          if constexpr (has_fp64_support) {
+            double wsum_fp64 = 0;
+            for (auto const i : point) { wsum_fp64 += expf(i - wmax); }
+            wsum = static_cast<float>(wsum_fp64);
+          } else {
+            for (auto const i : point) { wsum += expf(i - wmax); }
+          }
+
           auto label = labels[idx];
           if (label < 0 || label >= nclass) {
             _label_correct[0] = 0;
@@ -96,11 +104,11 @@ class SoftmaxMultiClassObj : public ObjFunction {
           bst_float wt = is_null_weight ? 1.0f : weights[idx];
           for (int k = 0; k < nclass; ++k) {
             // Computation duplicated to avoid creating a cache.
-            bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
+            bst_float p = expf(point[k] - wmax) / wsum;
             const float eps = 1e-16f;
-            const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, eps);
+            const bst_float h = 2.0f * p * (1.0f - p) * wt;
             p = label == k ? p - 1.0f : p;
-            gpair[idx * nclass + k] = GradientPair(p * wt, h);
+            gpair[idx * nclass + k] = GradientPair(p * wt, h < eps ? eps : h);
           }
         }, common::Range{0, ndata}, ctx_->Threads(), device)
         .Eval(out_gpair->Data(), info.labels.Data(), &preds, &info.weights_, &label_correct_);
@@ -129,7 +137,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
     auto device = io_preds->Device();
     if (prob) {
       common::Transform<>::Init(
-          [=] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
+          [=] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support, common::Span<bst_float> _preds) {
            common::Span<bst_float> point =
                _preds.subspan(_idx * nclass, nclass);
            common::Softmax(point.begin(), point.end());
@@ -142,7 +150,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
       max_preds.SetDevice(device);
       max_preds.Resize(ndata);
       common::Transform<>::Init(
-          [=] XGBOOST_DEVICE(size_t _idx, common::Span<const bst_float> _preds,
+          [=] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support,
+                             common::Span<const bst_float> _preds,
                              common::Span<bst_float> _max_preds) {
            common::Span<const bst_float> point =
                _preds.subspan(_idx * nclass, nclass);
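
For reference, the kernel above implements the usual max-shifted softmax; only the accumulator for the denominator wsum now depends on the fp64 tag. A host-side sketch of the same computation, illustrative only and not code from the patch:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Max-shifted softmax: subtracting the row maximum before exponentiating keeps
// exp in range; the denominator is accumulated in double when fp64 is available.
std::vector<float> Softmax(const std::vector<float>& point) {
  const float wmax = *std::max_element(point.begin(), point.end());
  double wsum = 0.0;  // the fp32-only path in the patch uses a float accumulator here
  for (float v : point) wsum += std::exp(v - wmax);
  std::vector<float> out(point.size());
  for (std::size_t k = 0; k < point.size(); ++k) {
    out[k] = static_cast<float>(std::exp(point[k] - wmax) / wsum);
  }
  return out;
}

int main() {
  for (float p : Softmax({1.0f, 2.0f, 3.0f})) std::printf("%f\n", p);
}
```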

src/objective/regression_obj.cu

Lines changed: 7 additions & 6 deletions
@@ -142,7 +142,8 @@ class RegLossObj : public FitInterceptGlmLike {
 
     common::Transform<>::Init(
         [block_size, ndata, n_targets] XGBOOST_DEVICE(
-            size_t data_block_idx, common::Span<float> _additional_input,
+            size_t data_block_idx, auto has_fp64_support,
+            common::Span<float> _additional_input,
             common::Span<GradientPair> _out_gpair,
             common::Span<const bst_float> _preds,
             common::Span<const bst_float> _labels,
@@ -179,7 +180,7 @@ class RegLossObj : public FitInterceptGlmLike {
 
   void PredTransform(HostDeviceVector<float> *io_preds) const override {
     common::Transform<>::Init(
-        [] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
+        [] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support, common::Span<float> _preds) {
          _preds[_idx] = Loss::PredTransform(_preds[_idx]);
        },
        common::Range{0, static_cast<int64_t>(io_preds->Size())}, this->ctx_->Threads(),
@@ -360,7 +361,7 @@ class PoissonRegression : public FitInterceptGlmLike {
   }
   bst_float max_delta_step = param_.max_delta_step;
   common::Transform<>::Init(
-      [=] XGBOOST_DEVICE(size_t _idx,
+      [=] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support,
                          common::Span<int> _label_correct,
                          common::Span<GradientPair> _out_gpair,
                          common::Span<const bst_float> _preds,
@@ -387,7 +388,7 @@ class PoissonRegression : public FitInterceptGlmLike {
   }
   void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     common::Transform<>::Init(
-        [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
+        [] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support, common::Span<bst_float> _preds) {
          _preds[_idx] = expf(_preds[_idx]);
        },
        common::Range{0, static_cast<int64_t>(io_preds->Size())}, this->ctx_->Threads(),
@@ -566,7 +567,7 @@ class TweedieRegression : public FitInterceptGlmLike {
 
   const float rho = param_.tweedie_variance_power;
   common::Transform<>::Init(
-      [=] XGBOOST_DEVICE(size_t _idx,
+      [=] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support,
                          common::Span<int> _label_correct,
                          common::Span<GradientPair> _out_gpair,
                          common::Span<const bst_float> _preds,
@@ -597,7 +598,7 @@ class TweedieRegression : public FitInterceptGlmLike {
   }
   void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     common::Transform<>::Init(
-        [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
+        [] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support, common::Span<bst_float> _preds) {
          _preds[_idx] = expf(_preds[_idx]);
        },
        common::Range{0, static_cast<int64_t>(io_preds->Size())}, this->ctx_->Threads(),

src/tree/split_evaluator.h

Lines changed: 2 additions & 1 deletion
@@ -180,7 +180,8 @@ class TreeEvaluator {
     }
 
     common::Transform<>::Init(
-        [=] XGBOOST_DEVICE(size_t, common::Span<float> lower,
+        [=] XGBOOST_DEVICE(size_t, auto has_fp64_support,
+                           common::Span<float> lower,
                            common::Span<float> upper,
                            common::Span<int> monotone) {
           lower[leftid] = lower[nodeid];

tests/cpp/common/test_transform_range.cc

Lines changed: 3 additions & 2 deletions
@@ -25,7 +25,8 @@ constexpr DeviceOrd TransformDevice() {
 
 template <typename T>
 struct TestTransformRange {
-  void XGBOOST_DEVICE operator()(std::size_t _idx, Span<float> _out, Span<const float> _in) {
+  template <class kBoolConst>
+  void XGBOOST_DEVICE operator()(std::size_t _idx, kBoolConst has_fp64_support, Span<float> _out, Span<const float> _in) {
     _out[_idx] = _in[_idx];
   }
 };
@@ -59,7 +60,7 @@ TEST(TransformDeathTest, Exception) {
   const HostDeviceVector<float> in_vec{h_in, DeviceOrd::CPU()};
   EXPECT_DEATH(
       {
-        Transform<>::Init([](size_t idx, common::Span<float const> _in) { _in[idx + 1]; },
+        Transform<>::Init([](size_t idx, auto has_fp64_support, common::Span<float const> _in) { _in[idx + 1]; },
                           Range(0, static_cast<Range::DifferenceType>(kSize)), AllThreadsForTest(),
                           DeviceOrd::CPU())
             .Eval(&in_vec);

tests/cpp/plugin/test_sycl_transform_range.cc

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,8 @@ namespace xgboost::common {
 
 template <typename T>
 struct TestTransformRange {
-  void operator()(std::size_t _idx, Span<float> _out, Span<const float> _in) {
+  template <class kBoolConst>
+  void operator()(std::size_t _idx, kBoolConst has_fp64_support, Span<float> _out, Span<const float> _in) {
    _out[_idx] = _in[_idx];
  }
};
