Skip to content

Commit 7025754

Browse files
committed
Avoid using mean intercept for rmsle.
1 parent e17f16a commit 7025754

File tree

3 files changed

+90
-48
lines changed

3 files changed

+90
-48
lines changed

src/objective/quantile_obj.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023-2024, XGBoost contributors
2+
* Copyright 2023-2025, XGBoost contributors
33
*/
44
#include <array> // std::array
55
#include <cstddef> // std::size_t
@@ -10,15 +10,13 @@
1010
#include "../common/quantile_loss_utils.h" // QuantileLossParam
1111
#include "../common/stats.h" // Quantile,WeightedQuantile
1212
#include "adaptive.h" // UpdateTreeLeaf
13-
#include "dmlc/parameter.h" // DMLC_DECLARE_PARAMETER
1413
#include "init_estimation.h" // CheckInitInputs
1514
#include "xgboost/base.h" // GradientPair,XGBOOST_DEVICE,bst_target_t
1615
#include "xgboost/data.h" // MetaInfo
1716
#include "xgboost/host_device_vector.h" // HostDeviceVector
1817
#include "xgboost/json.h" // Json,String,ToJson,FromJson
1918
#include "xgboost/linalg.h" // Tensor,MakeTensorView,MakeVec
2019
#include "xgboost/objective.h" // ObjFunction
21-
#include "xgboost/parameter.h" // XGBoostParameter
2220

2321
#if defined(XGBOOST_USE_CUDA)
2422

src/objective/regression_loss.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2017-2023 by XGBoost contributors
2+
* Copyright 2017-2025, XGBoost contributors
33
*/
44
#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
55
#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
@@ -9,7 +9,6 @@
99
#include <cmath>
1010

1111
#include "../common/math.h"
12-
#include "xgboost/data.h" // MetaInfo
1312
#include "xgboost/logging.h"
1413
#include "xgboost/task.h" // ObjInfo
1514

src/objective/regression_obj.cu

Lines changed: 88 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2015-2024, XGBoost Contributors
2+
* Copyright 2015-2025, XGBoost Contributors
33
* \file regression_obj.cu
44
* \brief Definition of single-value regression and classification objectives.
55
* \author Tianqi Chen, Kailong Chen
@@ -9,7 +9,6 @@
99
#include <algorithm>
1010
#include <cmath>
1111
#include <cstdint> // std::int32_t
12-
#include <memory>
1312
#include <vector>
1413

1514
#include "../common/common.h"
@@ -53,54 +52,55 @@ void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& pre
5352
CheckInitInputs(info);
5453
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
5554
}
55+
56+
template <typename Loss>
57+
void ValidateLabel(Context const* ctx, MetaInfo const& info) {
58+
auto label = info.labels.View(ctx->Device());
59+
auto valid = ctx->DispatchDevice(
60+
[&] {
61+
return std::all_of(linalg::cbegin(label), linalg::cend(label),
62+
[](float y) -> bool { return Loss::CheckLabel(y); });
63+
},
64+
[&] {
65+
#if defined(XGBOOST_USE_CUDA)
66+
auto cuctx = ctx->CUDACtx();
67+
auto it = dh::MakeTransformIterator<bool>(
68+
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> bool {
69+
auto [m, n] = linalg::UnravelIndex(i, label.Shape());
70+
return Loss::CheckLabel(label(m, n));
71+
});
72+
return dh::Reduce(cuctx->CTP(), it, it + label.Size(), true, thrust::logical_and<>{});
73+
#else
74+
common::AssertGPUSupport();
75+
return false;
76+
#endif // defined(XGBOOST_USE_CUDA)
77+
},
78+
[&] {
79+
#if defined(XGBOOST_USE_SYCL)
80+
return sycl::linalg::Validate(ctx->Device(), label,
81+
[](float y) -> bool { return Loss::CheckLabel(y); });
82+
#else
83+
common::AssertSYCLSupport();
84+
return false;
85+
#endif // defined(XGBOOST_USE_SYCL)
86+
});
87+
if (!valid) {
88+
LOG(FATAL) << Loss::LabelErrorMsg();
89+
}
90+
}
5691
} // anonymous namespace
5792

5893
#if defined(XGBOOST_USE_CUDA)
5994
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
6095
#endif // defined(XGBOOST_USE_CUDA)
6196

62-
63-
6497
template<typename Loss>
6598
class RegLossObj : public FitInterceptGlmLike {
6699
protected:
67100
HostDeviceVector<float> additional_input_;
68101

69102
public:
70-
void ValidateLabel(MetaInfo const& info) {
71-
auto label = info.labels.View(ctx_->Device());
72-
auto valid = ctx_->DispatchDevice(
73-
[&] {
74-
return std::all_of(linalg::cbegin(label), linalg::cend(label),
75-
[](float y) -> bool { return Loss::CheckLabel(y); });
76-
},
77-
[&] {
78-
#if defined(XGBOOST_USE_CUDA)
79-
auto cuctx = ctx_->CUDACtx();
80-
auto it = dh::MakeTransformIterator<bool>(
81-
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> bool {
82-
auto [m, n] = linalg::UnravelIndex(i, label.Shape());
83-
return Loss::CheckLabel(label(m, n));
84-
});
85-
return dh::Reduce(cuctx->CTP(), it, it + label.Size(), true, thrust::logical_and<>{});
86-
#else
87-
common::AssertGPUSupport();
88-
return false;
89-
#endif // defined(XGBOOST_USE_CUDA)
90-
},
91-
[&] {
92-
#if defined(XGBOOST_USE_SYCL)
93-
return sycl::linalg::Validate(ctx_->Device(), label,
94-
[](float y) -> bool { return Loss::CheckLabel(y); });
95-
#else
96-
common::AssertSYCLSupport();
97-
return false;
98-
#endif // defined(XGBOOST_USE_SYCL)
99-
});
100-
if (!valid) {
101-
LOG(FATAL) << Loss::LabelErrorMsg();
102-
}
103-
}
103+
104104
// 0 - scale_pos_weight, 1 - is_null_weight
105105
RegLossObj(): additional_input_(2) {}
106106

@@ -117,7 +117,7 @@ class RegLossObj : public FitInterceptGlmLike {
117117
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) override {
118118
CheckRegInputs(info, preds);
119119
if (iter == 0) {
120-
ValidateLabel(info);
120+
ValidateLabel<Loss>(this->ctx_, info);
121121
}
122122

123123
size_t const ndata = preds.Size();
@@ -224,10 +224,6 @@ XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression, LinearSquareLoss::Name())
224224
.describe("Regression with squared error.")
225225
.set_body([]() { return new RegLossObj<LinearSquareLoss>(); });
226226

227-
XGBOOST_REGISTER_OBJECTIVE(SquareLogError, SquaredLogError::Name())
228-
.describe("Regression with root mean squared logarithmic error.")
229-
.set_body([]() { return new RegLossObj<SquaredLogError>(); });
230-
231227
XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, LogisticRegression::Name())
232228
.describe("Logistic regression for probability regression task.")
233229
.set_body([]() { return new RegLossObj<LogisticRegression>(); });
@@ -253,6 +249,55 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
253249
return new RegLossObj<LinearSquareLoss>(); });
254250
// End deprecated
255251

252+
class SquaredLogErrorRegression : public FitIntercept {
253+
public:
254+
static auto Name() { return SquaredLogError::Name(); }
255+
256+
void Configure(Args const&) override {}
257+
[[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; }
258+
[[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override {
259+
return std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
260+
}
261+
void GetGradient(HostDeviceVector<bst_float> const& preds, const MetaInfo& info,
262+
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) override {
263+
if (iter == 0) {
264+
ValidateLabel<SquaredLogError>(this->ctx_, info);
265+
}
266+
auto labels = info.labels.View(ctx_->Device());
267+
268+
out_gpair->SetDevice(ctx_->Device());
269+
out_gpair->Reshape(info.num_row_, this->Targets(info));
270+
auto gpair = out_gpair->View(ctx_->Device());
271+
272+
preds.SetDevice(ctx_->Device());
273+
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));
274+
275+
info.weights_.SetDevice(ctx_->Device());
276+
common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
277+
: info.weights_.ConstDeviceSpan()};
278+
linalg::ElementWiseKernel(this->ctx_, labels,
279+
[=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
280+
auto p = predt(i, j);
281+
auto y = labels(i, j);
282+
auto w = weight[i];
283+
auto grad = SquaredLogError::FirstOrderGradient(p, y);
284+
auto hess = SquaredLogError::SecondOrderGradient(p, y);
285+
gpair(i, j) = {grad * w, hess * w};
286+
});
287+
}
288+
[[nodiscard]] const char* DefaultEvalMetric() const override { return "rmsle"; }
289+
290+
void SaveConfig(Json* p_out) const override {
291+
auto& out = *p_out;
292+
out["name"] = String(Name());
293+
}
294+
void LoadConfig(Json const&) override {}
295+
};
296+
297+
XGBOOST_REGISTER_OBJECTIVE(SquaredLogErrorRegression, SquaredLogErrorRegression::Name())
298+
.describe("Root mean squared log error.")
299+
.set_body([]() { return new SquaredLogErrorRegression(); });
300+
256301
class PseudoHuberRegression : public FitIntercept {
257302
PesudoHuberParam param_;
258303

0 commit comments

Comments
 (0)