Skip to content

Commit 42d100d

Browse files
authored
Make sure metrics work with federated learning (#9037)
1 parent ef13dd3 commit 42d100d

11 files changed

+451
-152
lines changed

src/collective/aggregator.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/**
2+
* Copyright 2023 by XGBoost contributors
3+
*
4+
* Higher level functions built on top the Communicator API, taking care of behavioral differences
5+
* between row-split vs column-split distributed training, and horizontal vs vertical federated
6+
* learning.
7+
*/
8+
#pragma once
9+
#include <xgboost/data.h>
10+
11+
#include <string>
12+
#include <utility>
13+
#include <vector>
14+
15+
#include "communicator-inl.h"
16+
17+
namespace xgboost {
18+
namespace collective {
19+
20+
/**
21+
* @brief Apply the given function where the labels are.
22+
*
23+
* Normally all the workers have access to the labels, so the function is just applied locally. In
24+
* vertical federated learning, we assume labels are only available on worker 0, so the function is
25+
* applied there, with the results broadcast to other workers.
26+
*
27+
* @tparam Function The function used to calculate the results.
28+
* @tparam Args Arguments to the function.
29+
* @param info MetaInfo about the DMatrix.
30+
* @param buffer The buffer storing the results.
31+
* @param size The size of the buffer.
32+
* @param function The function used to calculate the results.
33+
* @param args Arguments to the function.
34+
*/
35+
template <typename Function, typename... Args>
36+
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function,
37+
Args&&... args) {
38+
if (info.IsVerticalFederated()) {
39+
// We assume labels are only available on worker 0, so the calculation is done there and result
40+
// broadcast to other workers.
41+
std::vector<char> message(1024);
42+
if (collective::GetRank() == 0) {
43+
try {
44+
std::forward<Function>(function)(std::forward<Args>(args)...);
45+
} catch (dmlc::Error& e) {
46+
strncpy(&message[0], e.what(), message.size());
47+
message.back() = '\0';
48+
}
49+
}
50+
collective::Broadcast(&message[0], message.size(), 0);
51+
if (strlen(&message[0]) == 0) {
52+
collective::Broadcast(buffer, size, 0);
53+
} else {
54+
LOG(FATAL) << &message[0];
55+
}
56+
} else {
57+
std::forward<Function>(function)(std::forward<Args>(args)...);
58+
}
59+
}
60+
61+
} // namespace collective
62+
} // namespace xgboost

src/learner.cc

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <utility> // for pair, as_const, move, swap
3535
#include <vector> // for vector
3636

37+
#include "collective/aggregator.h" // for ApplyWithLabels
3738
#include "collective/communicator-inl.h" // for Allreduce, Broadcast, GetRank, IsDistributed
3839
#include "collective/communicator.h" // for Operation
3940
#include "common/api_entry.h" // for XGBAPIThreadLocalEntry
@@ -859,22 +860,10 @@ class LearnerConfiguration : public Learner {
859860
}
860861

861862
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
862-
// Special handling for vertical federated learning.
863-
if (info.IsVerticalFederated()) {
864-
// We assume labels are only available on worker 0, so the estimation is calculated there
865-
// and broadcast to other workers.
866-
if (collective::GetRank() == 0) {
867-
UsePtr(obj_)->InitEstimation(info, base_score);
868-
collective::Broadcast(base_score->Data()->HostPointer(),
869-
sizeof(bst_float) * base_score->Size(), 0);
870-
} else {
871-
base_score->Reshape(1);
872-
collective::Broadcast(base_score->Data()->HostPointer(),
873-
sizeof(bst_float) * base_score->Size(), 0);
874-
}
875-
} else {
876-
UsePtr(obj_)->InitEstimation(info, base_score);
877-
}
863+
base_score->Reshape(1);
864+
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
865+
sizeof(bst_float) * base_score->Size(),
866+
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
878867
}
879868
};
880869

@@ -1486,24 +1475,10 @@ class LearnerImpl : public LearnerIO {
14861475
private:
14871476
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
14881477
HostDeviceVector<GradientPair>* out_gpair) {
1489-
// Special handling for vertical federated learning.
1490-
if (info.IsVerticalFederated()) {
1491-
// We assume labels are only available on worker 0, so the gradients are calculated there
1492-
// and broadcast to other workers.
1493-
if (collective::GetRank() == 0) {
1494-
obj_->GetGradient(preds, info, iteration, out_gpair);
1495-
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
1496-
0);
1497-
} else {
1498-
CHECK_EQ(info.labels.Size(), 0)
1499-
<< "In vertical federated learning, labels should only be on the first worker";
1500-
out_gpair->Resize(preds.Size());
1501-
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
1502-
0);
1503-
}
1504-
} else {
1505-
obj_->GetGradient(preds, info, iteration, out_gpair);
1506-
}
1478+
out_gpair->Resize(preds.Size());
1479+
collective::ApplyWithLabels(info, out_gpair->HostPointer(),
1480+
out_gpair->Size() * sizeof(GradientPair),
1481+
[&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
15071482
}
15081483

15091484
/*! \brief random number transformation seed. */

src/metric/auc.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ class EvalAUC : public MetricNoCache {
270270
}
271271
// We use the global size to handle empty dataset.
272272
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
273-
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
273+
if (!info.IsVerticalFederated()) {
274+
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
275+
}
274276
if (meta[0] == 0) {
275277
// Empty across all workers, which is not supported.
276278
auc = std::numeric_limits<double>::quiet_NaN();

src/metric/metric_common.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include <memory> // shared_ptr
1010
#include <string>
1111

12+
#include "../collective/aggregator.h"
13+
#include "../collective/communicator-inl.h"
1214
#include "../common/common.h"
1315
#include "xgboost/metric.h"
1416

@@ -20,7 +22,12 @@ class MetricNoCache : public Metric {
2022
virtual double Eval(HostDeviceVector<float> const &predts, MetaInfo const &info) = 0;
2123

2224
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
23-
return this->Eval(predts, p_fmat->Info());
25+
double result{0.0};
26+
auto const& info = p_fmat->Info();
27+
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
28+
result = this->Eval(predts, info);
29+
});
30+
return result;
2431
}
2532
};
2633

src/metric/rank_metric.cc

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@
2828
#include <algorithm> // for stable_sort, copy, fill_n, min, max
2929
#include <array> // for array
3030
#include <cmath> // for log, sqrt
31-
#include <cstddef> // for size_t, std
32-
#include <cstdint> // for uint32_t
3331
#include <functional> // for less, greater
32+
#include <limits> // for numeric_limits
3433
#include <map> // for operator!=, _Rb_tree_const_iterator
3534
#include <memory> // for allocator, unique_ptr, shared_ptr, __shared_...
3635
#include <numeric> // for accumulate
@@ -39,15 +38,11 @@
3938
#include <utility> // for pair, make_pair
4039
#include <vector> // for vector
4140

42-
#include "../collective/communicator-inl.h" // for IsDistributed, Allreduce
43-
#include "../collective/communicator.h" // for Operation
41+
#include "../collective/aggregator.h" // for ApplyWithLabels
4442
#include "../common/algorithm.h" // for ArgSort, Sort
4543
#include "../common/linalg_op.h" // for cbegin, cend
4644
#include "../common/math.h" // for CmpFirst
4745
#include "../common/optional_weight.h" // for OptionalWeights, MakeOptionalWeights
48-
#include "../common/ranking_utils.h" // for LambdaRankParam, NDCGCache, ParseMetricName
49-
#include "../common/threading_utils.h" // for ParallelFor
50-
#include "../common/transform_iterator.h" // for IndexTransformIter
5146
#include "dmlc/common.h" // for OMPException
5247
#include "metric_common.h" // for MetricNoCache, GPUMetric, PackedReduceResult
5348
#include "xgboost/base.h" // for bst_float, bst_omp_uint, bst_group_t, Args
@@ -59,7 +54,6 @@
5954
#include "xgboost/linalg.h" // for Tensor, TensorView, Range, VectorView, MakeT...
6055
#include "xgboost/logging.h" // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ
6156
#include "xgboost/metric.h" // for MetricReg, XGBOOST_REGISTER_METRIC, Metric
62-
#include "xgboost/span.h" // for Span, operator!=
6357
#include "xgboost/string_view.h" // for StringView
6458

6559
namespace {
@@ -385,15 +379,19 @@ class EvalRankWithCache : public Metric {
385379
}
386380

387381
double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
382+
double result{0.0};
388383
auto const& info = p_fmat->Info();
389-
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
390-
if (p_cache->Param() != param_) {
391-
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
392-
}
393-
CHECK(p_cache->Param() == param_);
394-
CHECK_EQ(preds.Size(), info.labels.Size());
384+
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
385+
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
386+
if (p_cache->Param() != param_) {
387+
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
388+
}
389+
CHECK(p_cache->Param() == param_);
390+
CHECK_EQ(preds.Size(), info.labels.Size());
395391

396-
return this->Eval(preds, info, p_cache);
392+
result = this->Eval(preds, info, p_cache);
393+
});
394+
return result;
397395
}
398396

399397
virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,

tests/cpp/helpers.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,9 @@ double GetMultiMetricEval(xgboost::Metric* metric,
189189
info.weights_.HostVector() = weights;
190190
info.group_ptr_ = groups;
191191
info.data_split_mode = data_split_mode;
192-
192+
if (info.IsVerticalFederated() && xgboost::collective::GetRank() != 0) {
193+
info.labels.Reshape(0);
194+
}
193195
return metric->Evaluate(preds, p_fmat);
194196
}
195197

tests/cpp/metric/test_survival_metric.cu

Lines changed: 1 addition & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -2,109 +2,13 @@
22
* Copyright (c) by Contributors 2020
33
*/
44
#include <gtest/gtest.h>
5-
#include <cmath>
5+
#include "test_survival_metric.h"
66
#include "xgboost/metric.h"
7-
#include "../helpers.h"
8-
#include "../../../src/common/survival_util.h"
97

108
/** Tests for Survival metrics that should run both on CPU and GPU **/
119

1210
namespace xgboost {
1311
namespace common {
14-
namespace {
15-
inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
16-
auto ctx = CreateEmptyGenericParam(device);
17-
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
18-
metric->Configure(Args{});
19-
20-
HostDeviceVector<float> predts;
21-
auto p_fmat = EmptyDMatrix();
22-
MetaInfo& info = p_fmat->Info();
23-
auto &h_predts = predts.HostVector();
24-
25-
SimpleLCG lcg;
26-
SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
27-
28-
size_t n_samples = 2048;
29-
h_predts.resize(n_samples);
30-
31-
for (size_t i = 0; i < n_samples; ++i) {
32-
h_predts[i] = dist(&lcg);
33-
}
34-
35-
auto &h_upper = info.labels_upper_bound_.HostVector();
36-
auto &h_lower = info.labels_lower_bound_.HostVector();
37-
h_lower.resize(n_samples);
38-
h_upper.resize(n_samples);
39-
for (size_t i = 0; i < n_samples; ++i) {
40-
h_lower[i] = 1;
41-
h_upper[i] = 10;
42-
}
43-
44-
auto result = metric->Evaluate(predts, p_fmat);
45-
for (size_t i = 0; i < 8; ++i) {
46-
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
47-
}
48-
}
49-
50-
void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) {
51-
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
52-
53-
/**
54-
* Test aggregate output from the AFT metric over a small test data set.
55-
* This is unlike AFTLoss.* tests, which verify metric values over individual data points.
56-
**/
57-
auto p_fmat = EmptyDMatrix();
58-
MetaInfo& info = p_fmat->Info();
59-
info.num_row_ = 4;
60-
info.labels_lower_bound_.HostVector()
61-
= { 100.0f, 0.0f, 60.0f, 16.0f };
62-
info.labels_upper_bound_.HostVector()
63-
= { 100.0f, 20.0f, std::numeric_limits<bst_float>::infinity(), 200.0f };
64-
info.weights_.HostVector() = std::vector<bst_float>();
65-
info.data_split_mode = data_split_mode;
66-
HostDeviceVector<bst_float> preds(4, std::log(64));
67-
68-
struct TestCase {
69-
std::string dist_type;
70-
bst_float reference_value;
71-
};
72-
for (const auto& test_case : std::vector<TestCase>{ {"normal", 2.1508f}, {"logistic", 2.1804f},
73-
{"extreme", 2.0706f} }) {
74-
std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &ctx));
75-
metric->Configure({ {"aft_loss_distribution", test_case.dist_type},
76-
{"aft_loss_distribution_scale", "1.0"} });
77-
EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
78-
}
79-
}
80-
81-
void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) {
82-
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
83-
84-
auto p_fmat = EmptyDMatrix();
85-
MetaInfo& info = p_fmat->Info();
86-
info.num_row_ = 4;
87-
info.labels_lower_bound_.HostVector() = { 20.0f, 0.0f, 60.0f, 16.0f };
88-
info.labels_upper_bound_.HostVector() = { 80.0f, 20.0f, 80.0f, 200.0f };
89-
info.weights_.HostVector() = std::vector<bst_float>();
90-
info.data_split_mode = data_split_mode;
91-
HostDeviceVector<bst_float> preds(4, std::log(60.0f));
92-
93-
std::unique_ptr<Metric> metric(Metric::Create("interval-regression-accuracy", &ctx));
94-
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);
95-
info.labels_lower_bound_.HostVector()[2] = 70.0f;
96-
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
97-
info.labels_upper_bound_.HostVector()[2] = std::numeric_limits<bst_float>::infinity();
98-
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
99-
info.labels_upper_bound_.HostVector()[3] = std::numeric_limits<bst_float>::infinity();
100-
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
101-
info.labels_lower_bound_.HostVector()[0] = 70.0f;
102-
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);
103-
104-
CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
105-
}
106-
} // anonymous namespace
107-
10812
TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) { VerifyAFTNegLogLik(); }
10913

11014
TEST_F(DeclareUnifiedDistributedTest(MetricTest), AFTNegLogLikRowSplit) {
@@ -140,6 +44,5 @@ TEST(AFTNegLogLikMetric, DeclareUnifiedTest(Configuration)) {
14044

14145
CheckDeterministicMetricElementWise(StringView{"aft-nloglik"}, GPUIDX);
14246
}
143-
14447
} // namespace common
14548
} // namespace xgboost

0 commit comments

Comments
 (0)