Skip to content

Commit f53fc15

Browse files
committed
Change order.
1 parent d6f7f7a commit f53fc15

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

src/common/stats.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,14 @@ void Median(Context const* ctx, linalg::Matrix<float> const& t,
4545
}
4646
}
4747

48-
void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<float>* out) {
49-
v.SetDevice(ctx->Device());
48+
void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::Vector<float>* out) {
5049
out->SetDevice(ctx->Device());
5150
out->Reshape(1);
5251

5352
if (ctx->IsCUDA()) {
54-
cuda_impl::Mean(ctx, v.View(ctx->Device()), out->View(ctx->Device()));
53+
cuda_impl::Mean(ctx, v, out->View(ctx->Device()));
5554
} else {
56-
auto h_v = v.HostView();
55+
auto h_v = v;
5756
float n = v.Size();
5857
MemStackAllocator<float, DefaultMaxThreads()> tloc(ctx->Threads(), 0.0f);
5958
ParallelFor(v.Size(), ctx->Threads(),

src/common/stats.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ void Median(Context const* ctx, linalg::Matrix<float> const& t,
149149
/**
150150
* @brief Calculate the mean value of a vector.
151151
*/
152-
void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<float>* out);
152+
void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::Vector<float>* out);
153153

154154
/**
155155
* @brief Calculate the mean value for the first axis.

src/objective/multiclass_obj.cu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,14 +210,15 @@ class SoftmaxMultiClassObj : public ObjFunction {
210210
collective::SafeColl(status);
211211
CHECK_GE(sum_weight, kRtEps);
212212
linalg::VecScaDiv(this->ctx_, intercept, sum_weight);
213+
CHECK_EQ(base_score->Size(), n_classes);
213214

214215
// Transform it back to margin
215-
linalg::Vector<float> mean_intercepts;
216-
CHECK_EQ(base_score->Size(), n_classes);
217-
common::Mean(this->ctx_, *base_score, &mean_intercepts);
218-
auto d_mean = mean_intercepts.View(this->ctx_->Device());
219-
TransformKernel(this->ctx_, intercept,
220-
[=] XGBOOST_DEVICE(float v) { return log(v) - d_mean(0); });
216+
// ln(v) - E[ln(v)]
217+
linalg::Vector<float> mean;
218+
linalg::LogE(this->ctx_, intercept);
219+
common::Mean(this->ctx_, intercept, &mean);
220+
auto d_mean = mean.View(this->ctx_->Device());
221+
TransformKernel(this->ctx_, intercept, [=] XGBOOST_DEVICE(float v) { return v - d_mean(0); });
221222
}
222223

223224
private:

tests/cpp/common/test_stats.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ void TestMean(Context const* ctx) {
114114
float mean = nf * (nf - 1) / 2 / n;
115115

116116
linalg::Vector<float> res{{1}, ctx->Device()};
117-
Mean(ctx, data, &res);
117+
Mean(ctx, data.View(ctx->Device()), &res);
118118
auto h_res = res.HostView();
119119
ASSERT_EQ(h_res.Size(), 1);
120120
ASSERT_EQ(mean, h_res(0));

0 commit comments

Comments
 (0)