Skip to content

Commit ca7230f

Browse files
authored
Fix ranking tests with weight norm. (#11800)
The test might fail if the objective cache is renewed. This PR makes sure the objective and its cache is recreated for each test case, and multiple the normalization factor into the expected results.
1 parent bd54840 commit ca7230f

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

tests/cpp/objective/test_lambdarank_obj.cc

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,18 @@ void TestNDCGGPair(Context const* ctx) {
5656
{0, 2, 4},
5757
{2.06611f, -2.06611f, 0.0f, 0.0f},
5858
{2.169331f, 2.169331f, 0.0f, 0.0f});
59-
60-
CheckRankingObjFunction(obj,
61-
{0, 0.1f, 0, 0.1f},
62-
{0, 1, 0, 1},
63-
{2.0f, 2.0f},
64-
{0, 2, 4},
65-
{2.06611f, -2.06611f, 2.06611f, -2.06611f},
66-
{2.169331f, 2.169331f, 2.169331f, 2.169331f});
59+
}
60+
{
61+
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)};
62+
obj->Configure(Args{{"lambdarank_pair_method", "topk"}});
63+
float weight_norm = 0.5; // n_groups / sum_weights
64+
std::vector<float> out_grad{2.06611f, -2.06611f, 2.06611f, -2.06611f};
65+
std::vector<float> out_hess{2.169331f, 2.169331f, 2.169331f, 2.169331f};
66+
auto norm = [=](auto v) { return v * weight_norm; };
67+
std::transform(out_grad.begin(), out_grad.end(), out_grad.begin(), norm);
68+
std::transform(out_hess.begin(), out_hess.end(), out_hess.begin(), norm);
69+
CheckRankingObjFunction(obj, {0, 0.1f, 0, 0.1f}, {0, 1, 0, 1}, {2.0f, 2.0f}, {0, 2, 4},
70+
out_grad, out_hess);
6771
}
6872

6973
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)};
@@ -320,8 +324,7 @@ TEST(LambdaRank, MAPStat) {
320324

321325
void TestMAPGPair(Context const* ctx) {
322326
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:map", ctx)};
323-
Args args;
324-
obj->Configure(args);
327+
obj->Configure({});
325328

326329
CheckConfigReload(obj, "rank:map");
327330

@@ -332,14 +335,20 @@ void TestMAPGPair(Context const* ctx) {
332335
{0, 2, 4}, // group
333336
{1.2054923f, -1.2054923f, 1.2054923f, -1.2054923f}, // out grad
334337
{1.2657166f, 1.2657166f, 1.2657166f, 1.2657166f});
338+
339+
obj.reset(xgboost::ObjFunction::Create("rank:map", ctx));
340+
obj->Configure({});
341+
335342
// disable the second query group with 0 weight
336-
CheckRankingObjFunction(obj, // obj
337-
{0, 0.1f, 0, 0.1f}, // score
338-
{0, 1, 0, 1}, // label
339-
{2.0f, 0.0f}, // weight
340-
{0, 2, 4}, // group
341-
{1.2054923f, -1.2054923f, .0f, .0f}, // out grad
342-
{1.2657166f, 1.2657166f, .0f, .0f});
343+
auto w = 2.0f; // weight for the first group
344+
// weight norm is 1.0 (n_groups / sum_weights)
345+
CheckRankingObjFunction(obj, // obj
346+
{0, 0.1f, 0, 0.1f}, // score
347+
{0, 1, 0, 1}, // label
348+
{w, 0.0f}, // weight
349+
{0, 2, 4}, // group
350+
{1.2054923f * w, -1.2054923f * w, .0f, .0f}, // out grad
351+
{1.2657166f * w, 1.2657166f * w, .0f, .0f});
343352
}
344353

345354
TEST(LambdaRank, MAPGPair) {

0 commit comments

Comments
 (0)