1414#include " ../common/linalg_op.h"
1515#include " ../common/math.h"
1616#include " ../common/optional_weight.h" // for MakeOptionalWeights
17+ #include " ../common/stats.h" // for Mean
1718#include " ../common/transform.h"
1819#include " xgboost/data.h"
1920#include " xgboost/json.h"
@@ -197,7 +198,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
197198 *base_score = linalg::Zeros<float >(this ->ctx_ , n_classes);
198199
199200 std::size_t n = info.labels .Size ();
200-
201+ // Calculate probability
201202 auto labels = info.labels .View (ctx_->Device ());
202203 auto weights = common::MakeOptionalWeights (this ->ctx_ ->Device (), info.weights_ );
203204 auto intercept = base_score->View (ctx_->Device ());
@@ -210,15 +211,13 @@ class SoftmaxMultiClassObj : public ObjFunction {
210211 CHECK_GE (sum_weight, kRtEps );
211212 linalg::VecScaDiv (this ->ctx_ , intercept, sum_weight);
212213
213- double sum_intercepts = 0 .;
214- for (std::int64_t ix = 0 ; ix < n_classes; ix++) {
215- intercept (ix) = std::log (intercept (ix));
216- sum_intercepts += intercept (ix);
217- }
218- const double mean_intercepts = sum_intercepts / static_cast <double >(n_classes);
219- for (std::int64_t ix = 0 ; ix < n_classes; ix++) {
220- intercept (ix) -= mean_intercepts;
221- }
214+ // Transform it back to margin
215+ linalg::Vector<float > mean_intercepts;
216+ CHECK_EQ (base_score->Size (), n_classes);
217+ common::Mean (this ->ctx_ , *base_score, &mean_intercepts);
218+ auto d_mean = mean_intercepts.View (this ->ctx_ ->Device ());
219+ TransformKernel (this ->ctx_ , intercept,
220+ [=] XGBOOST_DEVICE (float v) { return log (v) - d_mean (0 ); });
222221 }
223222
224223 private:
0 commit comments