@@ -56,14 +56,18 @@ void TestNDCGGPair(Context const* ctx) {
5656 {0 , 2 , 4 },
5757 {2 .06611f , -2 .06611f , 0 .0f , 0 .0f },
5858 {2 .169331f , 2 .169331f , 0 .0f , 0 .0f });
59-
60- CheckRankingObjFunction (obj,
61- {0 , 0 .1f , 0 , 0 .1f },
62- {0 , 1 , 0 , 1 },
63- {2 .0f , 2 .0f },
64- {0 , 2 , 4 },
65- {2 .06611f , -2 .06611f , 2 .06611f , -2 .06611f },
66- {2 .169331f , 2 .169331f , 2 .169331f , 2 .169331f });
59+ }
60+ {
61+ std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create (" rank:ndcg" , ctx)};
62+ obj->Configure (Args{{" lambdarank_pair_method" , " topk" }});
63+ float weight_norm = 0.5 ; // n_groups / sum_weights
64+ std::vector<float > out_grad{2 .06611f , -2 .06611f , 2 .06611f , -2 .06611f };
65+ std::vector<float > out_hess{2 .169331f , 2 .169331f , 2 .169331f , 2 .169331f };
66+ auto norm = [=](auto v) { return v * weight_norm; };
67+ std::transform (out_grad.begin (), out_grad.end (), out_grad.begin (), norm);
68+ std::transform (out_hess.begin (), out_hess.end (), out_hess.begin (), norm);
69+ CheckRankingObjFunction (obj, {0 , 0 .1f , 0 , 0 .1f }, {0 , 1 , 0 , 1 }, {2 .0f , 2 .0f }, {0 , 2 , 4 },
70+ out_grad, out_hess);
6771 }
6872
6973 std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create (" rank:ndcg" , ctx)};
@@ -320,8 +324,7 @@ TEST(LambdaRank, MAPStat) {
320324
321325void TestMAPGPair (Context const * ctx) {
322326 std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create (" rank:map" , ctx)};
323- Args args;
324- obj->Configure (args);
327+ obj->Configure ({});
325328
326329 CheckConfigReload (obj, " rank:map" );
327330
@@ -332,14 +335,20 @@ void TestMAPGPair(Context const* ctx) {
332335 {0 , 2 , 4 }, // group
333336 {1 .2054923f , -1 .2054923f , 1 .2054923f , -1 .2054923f }, // out grad
334337 {1 .2657166f , 1 .2657166f , 1 .2657166f , 1 .2657166f });
338+
339+ obj.reset (xgboost::ObjFunction::Create (" rank:map" , ctx));
340+ obj->Configure ({});
341+
335342 // disable the second query group with 0 weight
336- CheckRankingObjFunction (obj, // obj
337- {0 , 0 .1f , 0 , 0 .1f }, // score
338- {0 , 1 , 0 , 1 }, // label
339- {2 .0f , 0 .0f }, // weight
340- {0 , 2 , 4 }, // group
341- {1 .2054923f , -1 .2054923f , .0f , .0f }, // out grad
342- {1 .2657166f , 1 .2657166f , .0f , .0f });
343+ auto w = 2 .0f ; // weight for the first group
344+ // weight norm is 1.0 (n_groups / sum_weights)
345+ CheckRankingObjFunction (obj, // obj
346+ {0 , 0 .1f , 0 , 0 .1f }, // score
347+ {0 , 1 , 0 , 1 }, // label
348+ {w, 0 .0f }, // weight
349+ {0 , 2 , 4 }, // group
350+ {1 .2054923f * w, -1 .2054923f * w, .0f , .0f }, // out grad
351+ {1 .2657166f * w, 1 .2657166f * w, .0f , .0f });
343352}
344353
345354TEST (LambdaRank, MAPGPair) {
0 commit comments