Skip to content

Commit d6f7f7a

Browse files
committed
device.
1 parent 6bb3a5d commit d6f7f7a

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

src/objective/multiclass_obj.cu

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "../common/linalg_op.h"
1515
#include "../common/math.h"
1616
#include "../common/optional_weight.h" // for MakeOptionalWeights
17+
#include "../common/stats.h" // for Mean
1718
#include "../common/transform.h"
1819
#include "xgboost/data.h"
1920
#include "xgboost/json.h"
@@ -197,7 +198,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
197198
*base_score = linalg::Zeros<float>(this->ctx_, n_classes);
198199

199200
std::size_t n = info.labels.Size();
200-
201+
// Calculate probability
201202
auto labels = info.labels.View(ctx_->Device());
202203
auto weights = common::MakeOptionalWeights(this->ctx_->Device(), info.weights_);
203204
auto intercept = base_score->View(ctx_->Device());
@@ -210,15 +211,13 @@ class SoftmaxMultiClassObj : public ObjFunction {
210211
CHECK_GE(sum_weight, kRtEps);
211212
linalg::VecScaDiv(this->ctx_, intercept, sum_weight);
212213

213-
double sum_intercepts = 0.;
214-
for (std::int64_t ix = 0; ix < n_classes; ix++) {
215-
intercept(ix) = std::log(intercept(ix));
216-
sum_intercepts += intercept(ix);
217-
}
218-
const double mean_intercepts = sum_intercepts / static_cast<double>(n_classes);
219-
for (std::int64_t ix = 0; ix < n_classes; ix++) {
220-
intercept(ix) -= mean_intercepts;
221-
}
214+
// Transform it back to margin
215+
linalg::Vector<float> mean_intercepts;
216+
CHECK_EQ(base_score->Size(), n_classes);
217+
common::Mean(this->ctx_, *base_score, &mean_intercepts);
218+
auto d_mean = mean_intercepts.View(this->ctx_->Device());
219+
TransformKernel(this->ctx_, intercept,
220+
[=] XGBOOST_DEVICE(float v) { return log(v) - d_mean(0); });
222221
}
223222

224223
private:

0 commit comments

Comments
 (0)