Skip to content

Commit 7ac273a

Browse files
committed
Cleanup the tests.
1 parent ac31340 commit 7ac273a

File tree

8 files changed

+143
-294
lines changed

8 files changed

+143
-294
lines changed

src/common/ranking_utils.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
*/
44
#ifndef XGBOOST_COMMON_RANKING_UTILS_H_
55
#define XGBOOST_COMMON_RANKING_UTILS_H_
6-
#include <algorithm> // for min
7-
#include <cmath> // for log2, fabs, floor
8-
#include <cstddef> // for size_t
9-
#include <cstdint> // for uint32_t, uint8_t, int32_t
10-
#include <limits> // for numeric_limits
11-
#include <string> // for char_traits, string
12-
#include <vector> // for vector
6+
#include <algorithm> // for min
7+
#include <cmath> // for log2, fabs, floor
8+
#include <cstddef> // for size_t
9+
#include <cstdint> // for uint32_t, uint8_t, int32_t
10+
#include <limits> // for numeric_limits
11+
#include <string> // for char_traits, string
12+
#include <vector> // for vector
1313

1414
#include "./math.h" // for CloseTo
1515
#include "dmlc/parameter.h" // for FieldEntry, DMLC_DECLARE_FIELD
@@ -71,8 +71,8 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
7171
// pairs
7272
// should be accessed by getter for auto configuration.
7373
// nolint so that we can keep the string name.
74-
PairMethod lambdarank_pair_method; // NOLINT
75-
std::size_t lambdarank_num_pair_per_sample; // NOLINT
74+
PairMethod lambdarank_pair_method; // NOLINT
75+
std::size_t lambdarank_num_pair_per_sample; // NOLINT
7676

7777
public:
7878
static constexpr position_t NotSet() { return std::numeric_limits<position_t>::max(); }

src/objective/init_estimation.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
#include "xgboost/linalg.h" // Tensor,Vector
1515
#include "xgboost/task.h" // ObjInfo
1616

17-
namespace xgboost {
18-
namespace obj {
17+
namespace xgboost::obj {
1918
void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const {
2019
if (this->Task().task == ObjInfo::kRegression) {
2120
CheckInitInputs(info);
@@ -31,14 +30,13 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* b
3130
ObjFunction::Create(get<String const>(config["name"]), this->ctx_)};
3231
new_obj->LoadConfig(config);
3332
new_obj->GetGradient(dummy_predt, info, 0, &gpair);
33+
3434
bst_target_t n_targets = this->Targets(info);
3535
linalg::Vector<float> leaf_weight;
3636
tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);
37-
3837
// workaround, we don't support multi-target due to binary model serialization for
3938
// base margin.
4039
common::Mean(this->ctx_, leaf_weight, base_score);
4140
this->PredTransform(base_score->Data());
4241
}
43-
} // namespace obj
44-
} // namespace xgboost
42+
} // namespace xgboost::obj

src/objective/lambdarank_obj.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,9 @@ void MAPStat(Context const* ctx, linalg::VectorView<float const> label,
402402
auto g_label = label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
403403

404404
// The number of relevant documents at each position
405-
g_n_rel[0] = label(g_rank[0]);
405+
g_n_rel[0] = g_label(g_rank[0]);
406406
for (std::size_t k = 1; k < g_rank.size(); ++k) {
407-
g_n_rel[k] = g_n_rel[k - 1] + label(g_rank[k]);
407+
g_n_rel[k] = g_n_rel[k - 1] + g_label(g_rank[k]);
408408
}
409409

410410
// \sum l_k/k

src/objective/lambdarank_obj.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ void MAPStat(Context const* ctx, MetaInfo const& info, common::Span<std::size_t
417417
[=] XGBOOST_DEVICE(std::size_t i) -> std::size_t { return dh::SegmentId(group_ptr, i); });
418418
auto label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
419419
auto const* cuctx = ctx->CUDACtx();
420+
420421
{
421422
// calculate number of relevant documents
422423
auto val_it = dh::MakeTransformIterator<double>(

src/objective/lambdarank_obj.h

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
*/
44
#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_
55
#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_
6-
#include <algorithm> // for min, max
7-
#include <cassert> // for assert
8-
#include <cmath> // for log, abs
9-
#include <cstddef> // for size_t
10-
#include <functional> // for greater
11-
#include <memory> // for shared_ptr
12-
#include <random> // for minstd_rand, uniform_int_distribution
13-
#include <vector> // for vector
6+
#include <algorithm> // for min, max
7+
#include <cassert> // for assert
8+
#include <cmath> // for log, abs
9+
#include <cstddef> // for size_t
10+
#include <functional> // for greater
11+
#include <memory> // for shared_ptr
12+
#include <random> // for minstd_rand, uniform_int_distribution
13+
#include <vector> // for vector
1414

1515
#include "../common/algorithm.h" // for ArgSort
1616
#include "../common/math.h" // for Sigmoid
@@ -24,8 +24,7 @@
2424
#include "xgboost/logging.h" // for CHECK_EQ
2525
#include "xgboost/span.h" // for Span
2626

27-
namespace xgboost {
28-
namespace obj {
27+
namespace xgboost::obj {
2928
template <bool exp>
3029
XGBOOST_DEVICE double DeltaNDCG(float y_high, float y_low, std::size_t r_high, std::size_t r_low,
3130
double inv_IDCG, common::Span<double const> discount) {
@@ -51,6 +50,7 @@ XGBOOST_DEVICE inline double DeltaMAP(float y_high, float y_low, std::size_t ran
5150
double r_l = static_cast<double>(rank_low) + 1.0;
5251
double delta{0.0};
5352
double n_total_relevances = n_rel.back();
53+
assert(n_total_relevances > 0.0);
5454
auto m = n_rel[rank_low];
5555
double n = n_rel[rank_high];
5656

@@ -258,6 +258,5 @@ void MakePairs(Context const* ctx, std::int32_t iter,
258258
}
259259
}
260260
}
261-
} // namespace obj
262-
} // namespace xgboost
261+
} // namespace xgboost::obj
263262
#endif // XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_

tests/cpp/objective/test_lambdarank_obj.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,80 @@ TEST(LambdaRank, MakePair) {
163163
}
164164
}
165165

166+
void TestMAPStat(Context const* ctx) {
167+
auto p_fmat = EmptyDMatrix();
168+
MetaInfo& info = p_fmat->Info();
169+
ltr::LambdaRankParam param;
170+
param.UpdateAllowUnknown(Args{});
171+
172+
{
173+
std::vector<float> h_data{1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f};
174+
info.labels.Reshape(h_data.size(), 1);
175+
info.labels.Data()->HostVector() = h_data;
176+
info.num_row_ = h_data.size();
177+
178+
HostDeviceVector<float> predt;
179+
auto& h_predt = predt.HostVector();
180+
h_predt.resize(h_data.size());
181+
std::iota(h_predt.rbegin(), h_predt.rend(), 0.0f);
182+
183+
auto p_cache = std::make_shared<ltr::MAPCache>(ctx, info, param);
184+
185+
predt.SetDevice(ctx->gpu_id);
186+
auto rank_idx =
187+
p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan());
188+
189+
if (ctx->IsCPU()) {
190+
obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
191+
p_cache);
192+
} else {
193+
obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache);
194+
}
195+
196+
Context cpu_ctx;
197+
auto n_rel = p_cache->NumRelevant(&cpu_ctx);
198+
auto acc = p_cache->Acc(&cpu_ctx);
199+
200+
ASSERT_EQ(n_rel[0], 1.0);
201+
ASSERT_EQ(acc[0], 1.0);
202+
203+
ASSERT_EQ(n_rel.back(), h_data.size() - 1.0);
204+
ASSERT_NEAR(acc.back(), 1.95 + (1.0 / h_data.size()), kRtEps);
205+
}
206+
{
207+
info.labels.Reshape(16);
208+
auto& h_label = info.labels.Data()->HostVector();
209+
info.group_ptr_ = {0, 8, 16};
210+
info.num_row_ = info.labels.Shape(0);
211+
212+
std::fill_n(h_label.begin(), 8, 1.0f);
213+
std::fill_n(h_label.begin() + 8, 8, 0.0f);
214+
HostDeviceVector<float> predt;
215+
auto& h_predt = predt.HostVector();
216+
h_predt.resize(h_label.size());
217+
std::iota(h_predt.rbegin(), h_predt.rbegin() + 8, 0.0f);
218+
std::iota(h_predt.rbegin() + 8, h_predt.rend(), 0.0f);
219+
220+
auto p_cache = std::make_shared<ltr::MAPCache>(ctx, info, param);
221+
222+
predt.SetDevice(ctx->gpu_id);
223+
auto rank_idx =
224+
p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan());
225+
226+
if (ctx->IsCPU()) {
227+
obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
228+
p_cache);
229+
} else {
230+
obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache);
231+
}
232+
233+
Context cpu_ctx;
234+
auto n_rel = p_cache->NumRelevant(&cpu_ctx);
235+
ASSERT_EQ(n_rel[7], 8); // first group
236+
ASSERT_EQ(n_rel.back(), 0); // second group
237+
}
238+
}
239+
166240
TEST(LambdaRank, MAPStat) {
167241
Context ctx;
168242
TestMAPStat(&ctx);

tests/cpp/objective/test_lambdarank_obj.h

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,46 +17,8 @@
1717
#include "../../../src/objective/lambdarank_obj.h" // for MAPStat
1818
#include "../helpers.h" // for EmptyDMatrix
1919

20-
namespace xgboost {
21-
namespace obj {
22-
inline void TestMAPStat(Context const* ctx) {
23-
auto p_fmat = EmptyDMatrix();
24-
MetaInfo& info = p_fmat->Info();
25-
ltr::LambdaRankParam param;
26-
param.UpdateAllowUnknown(Args{});
27-
28-
std::vector<float> h_data{1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f};
29-
info.labels.Reshape(h_data.size(), 1);
30-
info.labels.Data()->HostVector() = h_data;
31-
info.num_row_ = h_data.size();
32-
33-
HostDeviceVector<float> predt;
34-
auto& h_predt = predt.HostVector();
35-
h_predt.resize(h_data.size());
36-
std::iota(h_predt.rbegin(), h_predt.rend(), 0.0f);
37-
38-
auto p_cache = std::make_shared<ltr::MAPCache>(ctx, info, param);
39-
40-
predt.SetDevice(ctx->gpu_id);
41-
auto rank_idx =
42-
p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan());
43-
44-
if (ctx->IsCPU()) {
45-
obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, p_cache);
46-
} else {
47-
obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache);
48-
}
49-
50-
Context cpu_ctx;
51-
auto n_rel = p_cache->NumRelevant(&cpu_ctx);
52-
auto acc = p_cache->Acc(&cpu_ctx);
53-
54-
ASSERT_EQ(n_rel[0], 1.0);
55-
ASSERT_EQ(acc[0], 1.0);
56-
57-
ASSERT_EQ(n_rel.back(), h_data.size() - 1.0);
58-
ASSERT_NEAR(acc.back(), 1.95 + (1.0 / h_data.size()), kRtEps);
59-
}
20+
namespace xgboost::obj {
21+
void TestMAPStat(Context const* ctx);
6022

6123
inline void TestNDCGJsonIO(Context const* ctx) {
6224
std::unique_ptr<xgboost::ObjFunction> obj{ObjFunction::Create("rank:ndcg", ctx)};
@@ -80,6 +42,5 @@ void TestMAPGPair(Context const* ctx);
8042
* \brief Initialize test data for make pair tests.
8143
*/
8244
void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector<float>* out_predt);
83-
} // namespace obj
84-
} // namespace xgboost
45+
} // namespace xgboost::obj
8546
#endif // XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_

0 commit comments

Comments
 (0)