Skip to content

Commit c2d06bd

Browse files
authored
Fix shap with vector intercept. (#11764)
1 parent a340adc commit c2d06bd

File tree

4 files changed

+70
-27
lines changed

4 files changed

+70
-27
lines changed

src/predictor/cpu_predictor.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,14 @@ struct DataToFeatVec {
257257

258258
template <typename EncAccessor>
259259
class SparsePageView : public DataToFeatVec<SparsePageView<EncAccessor>> {
260-
EncAccessor const &acc_;
260+
EncAccessor acc_;
261261
HostSparsePageView const view_;
262262

263263
public:
264264
bst_idx_t const base_rowid;
265265

266-
SparsePageView(HostSparsePageView const p, bst_idx_t base_rowid, EncAccessor const &acc)
267-
: acc_{acc}, view_{p}, base_rowid{base_rowid} {}
266+
SparsePageView(HostSparsePageView const p, bst_idx_t base_rowid, EncAccessor acc)
267+
: acc_{std::move(acc)}, view_{p}, base_rowid{base_rowid} {}
268268
[[nodiscard]] std::size_t Size() const { return view_.Size(); }
269269

270270
[[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const {
@@ -283,7 +283,7 @@ template <typename EncAccessor>
283283
class GHistIndexMatrixView : public DataToFeatVec<GHistIndexMatrixView<EncAccessor>> {
284284
private:
285285
GHistIndexMatrix const &page_;
286-
EncAccessor const &acc_;
286+
EncAccessor acc_;
287287
common::Span<FeatureType const> ft_;
288288

289289
std::vector<std::uint32_t> const &ptrs_;
@@ -295,10 +295,10 @@ class GHistIndexMatrixView : public DataToFeatVec<GHistIndexMatrixView<EncAccess
295295
bst_idx_t const base_rowid;
296296

297297
public:
298-
GHistIndexMatrixView(GHistIndexMatrix const &_page, EncAccessor const &acc,
298+
GHistIndexMatrixView(GHistIndexMatrix const &_page, EncAccessor acc,
299299
common::Span<FeatureType const> ft)
300300
: page_{_page},
301-
acc_{acc},
301+
acc_{std::move(acc)},
302302
ft_{ft},
303303
ptrs_{_page.cut.Ptrs()},
304304
mins_{_page.cut.MinValues()},
@@ -365,11 +365,11 @@ template <typename Adapter, typename EncAccessor>
365365
class AdapterView : public DataToFeatVec<AdapterView<Adapter, EncAccessor>> {
366366
Adapter const *adapter_;
367367
float missing_;
368-
EncAccessor const &acc_;
368+
EncAccessor acc_;
369369

370370
public:
371-
explicit AdapterView(Adapter const *adapter, float missing, EncAccessor const &acc)
372-
: adapter_{adapter}, missing_{missing}, acc_{acc} {}
371+
explicit AdapterView(Adapter const *adapter, float missing, EncAccessor acc)
372+
: adapter_{adapter}, missing_{missing}, acc_{std::move(acc)} {}
373373

374374
[[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const {
375375
auto const &batch = adapter_->Value();
@@ -408,7 +408,7 @@ struct EncAccessorPolicy {
408408
[[nodiscard]] auto MakeAccessor(Context const *ctx, enc::HostColumnsView new_enc,
409409
gbm::GBTreeModel const &model) {
410410
auto [acc, mapping] = MakeCatAccessor(ctx, new_enc, model.Cats());
411-
this->mapping_ = std::move(mapping);
411+
std::swap(mapping, this->mapping_);
412412
return acc;
413413
}
414414
};
@@ -923,7 +923,7 @@ class CPUPredictor : public Predictor {
923923
CHECK_NE(ncolumns, 0);
924924
auto device = ctx_->Device().IsSycl() ? DeviceOrd::CPU() : ctx_->Device();
925925
auto base_margin = info.base_margin_.View(device);
926-
auto base_score = model.learner_model_param->BaseScore(device)(0);
926+
auto base_score = model.learner_model_param->BaseScore(device);
927927

928928
// parallel over local batch
929929
common::ParallelFor(batch.Size(), this->ctx_->Threads(), [&](auto i) {
@@ -962,7 +962,7 @@ class CPUPredictor : public Predictor {
962962
CHECK_EQ(base_margin.Shape(1), ngroup);
963963
p_contribs[ncolumns - 1] += base_margin(row_idx, gid);
964964
} else {
965-
p_contribs[ncolumns - 1] += base_score;
965+
p_contribs[ncolumns - 1] += base_score(gid);
966966
}
967967
}
968968
});

src/predictor/gpu_predictor.cu

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,7 @@ class GPUPredictor : public xgboost::Predictor {
11291129
// allocate space for (number of features + bias) times the number of rows
11301130
size_t contributions_columns = model.learner_model_param->num_feature + 1; // +1 for bias
11311131
auto dim_size = contributions_columns * model.learner_model_param->num_output_group;
1132+
// Output shape: [n_samples, n_classes, n_features + 1]
11321133
out_contribs->Resize(p_fmat->Info().num_row_ * dim_size);
11331134
out_contribs->Fill(0.0f);
11341135
auto phis = out_contribs->DeviceSpan();
@@ -1159,11 +1160,11 @@ class GPUPredictor : public xgboost::Predictor {
11591160
const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
11601161

11611162
auto base_score = model.learner_model_param->BaseScore(ctx_);
1162-
dh::LaunchN(p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
1163-
ctx_->CUDACtx()->Stream(), [=] __device__(size_t idx) {
1164-
phis[(idx + 1) * contributions_columns - 1] +=
1165-
margin.empty() ? base_score(0) : margin[idx];
1166-
});
1163+
bst_idx_t n_samples = p_fmat->Info().num_row_;
1164+
dh::LaunchN(n_samples * ngroup, ctx_->CUDACtx()->Stream(), [=] __device__(std::size_t idx) {
1165+
auto [_, gid] = linalg::UnravelIndex(idx, n_samples, ngroup);
1166+
phis[(idx + 1) * contributions_columns - 1] += margin.empty() ? base_score(gid) : margin[idx];
1167+
});
11671168
}
11681169

11691170
void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
@@ -1219,14 +1220,13 @@ class GPUPredictor : public xgboost::Predictor {
12191220

12201221
auto base_score = model.learner_model_param->BaseScore(ctx_);
12211222
size_t n_features = model.learner_model_param->num_feature;
1222-
dh::LaunchN(p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
1223-
ctx_->CUDACtx()->Stream(), [=] __device__(size_t idx) {
1224-
size_t group = idx % ngroup;
1225-
size_t row_idx = idx / ngroup;
1226-
phis[gpu_treeshap::IndexPhiInteractions(row_idx, ngroup, group, n_features,
1227-
n_features, n_features)] +=
1228-
margin.empty() ? base_score(0) : margin[idx];
1229-
});
1223+
bst_idx_t n_samples = p_fmat->Info().num_row_;
1224+
dh::LaunchN(n_samples * ngroup, ctx_->CUDACtx()->Stream(), [=] __device__(size_t idx) {
1225+
auto [ridx, gidx] = linalg::UnravelIndex(idx, n_samples, ngroup);
1226+
phis[gpu_treeshap::IndexPhiInteractions(ridx, ngroup, gidx, n_features, n_features,
1227+
n_features)] +=
1228+
margin.empty() ? base_score(gidx) : margin[idx];
1229+
});
12301230
}
12311231

12321232
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<float>* predictions,

tests/cpp/predictor/test_predictor.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ void ShapExternalMemoryTest::Run(Context const *ctx, bool is_qdm, bool is_intera
860860
.Classes(n_classes));
861861
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
862862
learner->SetParam("device", ctx->DeviceName());
863-
learner->SetParam("base_score", "0.5");
863+
learner->SetParam("base_score", "[0.5, 0.5, 0.5]");
864864
learner->SetParam("num_parallel_tree", "3");
865865
learner->SetParam("max_bin", std::to_string(max_bin));
866866
for (std::int32_t i = 0; i < 4; ++i) {
@@ -869,8 +869,10 @@ void ShapExternalMemoryTest::Run(Context const *ctx, bool is_qdm, bool is_intera
869869
Json model{Object{}};
870870
learner->SaveModel(&model);
871871
auto j_booster = model["learner"]["gradient_booster"]["model"];
872-
auto model_param = MakeMP(n_features, 0.0, n_classes, ctx->Device());
873872

873+
auto base_score = linalg::Tensor<float, 1>{{0.0, 0.0, 0.0}, {3}, ctx->Device()};
874+
LearnerModelParam model_param(n_features, std::move(base_score), n_classes, 1,
875+
MultiStrategy::kOneOutputPerTree);
874876
gbm::GBTreeModel gbtree{&model_param, ctx};
875877
gbtree.LoadModel(j_booster);
876878

tests/python-gpu/test_gpu_prediction.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,3 +628,44 @@ def test_dtypes(self):
628628

629629
def test_base_margin_vs_base_score() -> None:
630630
run_base_margin_vs_base_score("cuda")
631+
632+
633+
@pytest.mark.skipif(**tm.no_sklearn())
634+
def test_shap_multiclass() -> None:
635+
from sklearn.datasets import make_classification
636+
637+
X, y = make_classification(n_classes=3, random_state=2025, n_informative=16)
638+
param = {
639+
"tree_method": "hist",
640+
"device": "cuda",
641+
"num_class": 3,
642+
"base_score": [1.0, 2.0, 3.0],
643+
}
644+
Xy = xgb.DMatrix(X, y)
645+
bst = xgb.train(param, Xy, 8)
646+
647+
d_shap = bst.predict(Xy, pred_contribs=True)
648+
d_margin = bst.predict(Xy, output_margin=True)
649+
650+
bst.set_param({"device": "cpu"})
651+
652+
h_shap = bst.predict(Xy, pred_contribs=True)
653+
h_margin = bst.predict(Xy, output_margin=True)
654+
655+
np.testing.assert_allclose(d_shap, h_shap, atol=1e-6)
656+
np.testing.assert_allclose(d_margin, h_margin, atol=1e-6)
657+
658+
# Compare base margin and base score
659+
margin = np.stack(
660+
[
661+
np.ones(X.shape[0]),
662+
np.full(X.shape[0], fill_value=2.0),
663+
np.full(X.shape[0], fill_value=3.0),
664+
],
665+
axis=1,
666+
)
667+
Xy = xgb.DMatrix(X, y, base_margin=margin)
668+
669+
bst.set_param({"device": "cuda"})
670+
d_shap = bst.predict(Xy, pred_contribs=True)
671+
np.testing.assert_allclose(d_shap, h_shap, atol=1e-6)

0 commit comments

Comments
 (0)