Skip to content

Commit ad2e715

Browse files
authored
Small cleanup to the proxy dmatrix. (#11594)
- Remove the auto-dispatch for CUDA methods. - Consistent naming. - Cleanup C API.
1 parent 86a9809 commit ad2e715

File tree

14 files changed

+103
-113
lines changed

14 files changed

+103
-113
lines changed

src/c_api/c_api.cc

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -422,52 +422,45 @@ XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle *out) {
422422
API_END();
423423
}
424424

425-
XGB_DLL int XGProxyDMatrixSetDataCudaArrayInterface(DMatrixHandle handle,
426-
char const *c_interface_str) {
427-
API_BEGIN();
428-
CHECK_HANDLE();
429-
xgboost_CHECK_C_ARG_PTR(c_interface_str);
425+
namespace {
426+
[[nodiscard]] xgboost::data::DMatrixProxy *GetDMatrixProxy(DMatrixHandle handle) {
430427
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
431428
CHECK(p_m);
432429
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
433430
CHECK(m) << "Current DMatrix type does not support set data.";
434-
m->SetCUDAArray(c_interface_str);
431+
return m;
432+
}
433+
} // namespace
434+
435+
XGB_DLL int XGProxyDMatrixSetDataCudaArrayInterface(DMatrixHandle handle, char const *data) {
436+
API_BEGIN();
437+
CHECK_HANDLE();
438+
xgboost_CHECK_C_ARG_PTR(data);
439+
GetDMatrixProxy(handle)->SetCudaArray(data);
435440
API_END();
436441
}
437442

438-
XGB_DLL int XGProxyDMatrixSetDataCudaColumnar(DMatrixHandle handle, char const *c_interface_str) {
443+
XGB_DLL int XGProxyDMatrixSetDataCudaColumnar(DMatrixHandle handle, char const *data) {
439444
API_BEGIN();
440445
CHECK_HANDLE();
441-
xgboost_CHECK_C_ARG_PTR(c_interface_str);
442-
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
443-
CHECK(p_m);
444-
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
445-
CHECK(m) << "Current DMatrix type does not support set data.";
446-
m->SetCUDAArray(c_interface_str);
446+
xgboost_CHECK_C_ARG_PTR(data);
447+
GetDMatrixProxy(handle)->SetCudaColumnar(data);
447448
API_END();
448449
}
449450

450-
XGB_DLL int XGProxyDMatrixSetDataColumnar(DMatrixHandle handle, char const *c_interface_str) {
451+
XGB_DLL int XGProxyDMatrixSetDataColumnar(DMatrixHandle handle, char const *data) {
451452
API_BEGIN();
452453
CHECK_HANDLE();
453-
xgboost_CHECK_C_ARG_PTR(c_interface_str);
454-
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
455-
CHECK(p_m);
456-
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
457-
CHECK(m) << "Current DMatrix type does not support set data.";
458-
m->SetColumnarData(c_interface_str);
454+
xgboost_CHECK_C_ARG_PTR(data);
455+
GetDMatrixProxy(handle)->SetColumnar(data);
459456
API_END();
460457
}
461458

462-
XGB_DLL int XGProxyDMatrixSetDataDense(DMatrixHandle handle, char const *c_interface_str) {
459+
XGB_DLL int XGProxyDMatrixSetDataDense(DMatrixHandle handle, char const *data) {
463460
API_BEGIN();
464461
CHECK_HANDLE();
465-
xgboost_CHECK_C_ARG_PTR(c_interface_str);
466-
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
467-
CHECK(p_m);
468-
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
469-
CHECK(m) << "Current DMatrix type does not support set data.";
470-
m->SetArrayData(c_interface_str);
462+
xgboost_CHECK_C_ARG_PTR(data);
463+
GetDMatrixProxy(handle)->SetArray(data);
471464
API_END();
472465
}
473466

@@ -478,11 +471,7 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr, c
478471
xgboost_CHECK_C_ARG_PTR(indptr);
479472
xgboost_CHECK_C_ARG_PTR(indices);
480473
xgboost_CHECK_C_ARG_PTR(data);
481-
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
482-
CHECK(p_m);
483-
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
484-
CHECK(m) << "Current DMatrix type does not support set data.";
485-
m->SetCSRData(indptr, indices, data, ncol, true);
474+
GetDMatrixProxy(handle)->SetCsr(indptr, indices, data, ncol, true);
486475
API_END();
487476
}
488477

@@ -1402,7 +1391,7 @@ void InplacePredictImpl(std::shared_ptr<DMatrix> p_m, char const *c_json_config,
14021391
*out_shape = dmlc::BeginPtr(shape);
14031392
}
14041393

1405-
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *array_interface,
1394+
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *data,
14061395
char const *c_json_config, DMatrixHandle m,
14071396
xgboost::bst_ulong const **out_shape,
14081397
xgboost::bst_ulong *out_dim, const float **out_result) {
@@ -1416,8 +1405,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *array_in
14161405
}
14171406
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
14181407
CHECK(proxy) << "Invalid input type for inplace predict.";
1419-
xgboost_CHECK_C_ARG_PTR(array_interface);
1420-
proxy->SetArrayData(array_interface);
1408+
xgboost_CHECK_C_ARG_PTR(data);
1409+
proxy->SetArray(data);
14211410
auto *learner = static_cast<xgboost::Learner *>(handle);
14221411
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
14231412
API_END();
@@ -1438,7 +1427,7 @@ XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array
14381427
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
14391428
CHECK(proxy) << "Invalid input type for inplace predict.";
14401429
xgboost_CHECK_C_ARG_PTR(array_interface);
1441-
proxy->SetColumnarData(array_interface);
1430+
proxy->SetColumnar(array_interface);
14421431
auto *learner = static_cast<xgboost::Learner *>(handle);
14431432
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
14441433
API_END();
@@ -1460,7 +1449,7 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, ch
14601449
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
14611450
CHECK(proxy) << "Invalid input type for inplace predict.";
14621451
xgboost_CHECK_C_ARG_PTR(indptr);
1463-
proxy->SetCSRData(indptr, indices, data, cols, true);
1452+
proxy->SetCsr(indptr, indices, data, cols, true);
14641453
auto *learner = static_cast<xgboost::Learner *>(handle);
14651454
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
14661455
API_END();

src/c_api/c_api.cu

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -138,19 +138,24 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
138138
API_END();
139139
}
140140

141-
int InplacePreidctCUDA(BoosterHandle handle, char const *c_array_interface,
142-
char const *c_json_config, std::shared_ptr<DMatrix> p_m,
143-
xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
144-
const float **out_result) {
141+
template <bool is_columnar>
142+
int InplacePreidctCUDA(BoosterHandle handle, char const *data, char const *c_json_config,
143+
std::shared_ptr<DMatrix> p_m, xgboost::bst_ulong const **out_shape,
144+
xgboost::bst_ulong *out_dim, const float **out_result) {
145145
API_BEGIN();
146146
CHECK_HANDLE();
147147
if (!p_m) {
148148
p_m.reset(new data::DMatrixProxy);
149149
}
150150
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
151151
CHECK(proxy) << "Invalid input type for inplace predict.";
152+
xgboost_CHECK_C_ARG_PTR(data);
152153

153-
proxy->SetCUDAArray(c_array_interface);
154+
if constexpr (is_columnar) {
155+
proxy->SetCudaColumnar(data);
156+
} else {
157+
proxy->SetCudaArray(data);
158+
}
154159

155160
auto config = Json::Load(StringView{c_json_config});
156161
auto *learner = static_cast<Learner *>(handle);
@@ -184,7 +189,7 @@ int InplacePreidctCUDA(BoosterHandle handle, char const *c_array_interface,
184189
API_END();
185190
}
186191

187-
XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *c_json_strs,
192+
XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *data,
188193
char const *c_json_config, DMatrixHandle m,
189194
xgboost::bst_ulong const **out_shape,
190195
xgboost::bst_ulong *out_dim,
@@ -194,11 +199,10 @@ XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *c
194199
if (m) {
195200
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
196201
}
197-
return InplacePreidctCUDA(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
198-
out_result);
202+
return InplacePreidctCUDA<true>(handle, data, c_json_config, p_m, out_shape, out_dim, out_result);
199203
}
200204

201-
XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *c_json_strs,
205+
XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *data,
202206
char const *c_json_config, DMatrixHandle m,
203207
xgboost::bst_ulong const **out_shape,
204208
xgboost::bst_ulong *out_dim, const float **out_result) {
@@ -207,6 +211,6 @@ XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *c_js
207211
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
208212
}
209213
xgboost_CHECK_C_ARG_PTR(out_result);
210-
return InplacePreidctCUDA(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
211-
out_result);
214+
return InplacePreidctCUDA<false>(handle, data, c_json_config, p_m, out_shape, out_dim,
215+
out_result);
212216
}

src/data/proxy_dmatrix.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2021-2024, XGBoost Contributors
2+
* Copyright 2021-2025, XGBoost Contributors
33
*/
44

55
#include "proxy_dmatrix.h"
@@ -16,23 +16,23 @@
1616
#endif
1717

1818
namespace xgboost::data {
19-
void DMatrixProxy::SetColumnarData(StringView interface_str) {
20-
std::shared_ptr<ColumnarAdapter> adapter{new ColumnarAdapter{interface_str}};
19+
void DMatrixProxy::SetColumnar(StringView data) {
20+
std::shared_ptr<ColumnarAdapter> adapter{new ColumnarAdapter{data}};
2121
this->batch_ = adapter;
2222
this->Info().num_col_ = adapter->NumColumns();
2323
this->Info().num_row_ = adapter->NumRows();
2424
this->ctx_.Init(Args{{"device", "cpu"}});
2525
}
2626

27-
void DMatrixProxy::SetArrayData(StringView interface_str) {
28-
std::shared_ptr<ArrayAdapter> adapter{new ArrayAdapter{interface_str}};
27+
void DMatrixProxy::SetArray(StringView data) {
28+
std::shared_ptr<ArrayAdapter> adapter{new ArrayAdapter{data}};
2929
this->batch_ = adapter;
3030
this->Info().num_col_ = adapter->NumColumns();
3131
this->Info().num_row_ = adapter->NumRows();
3232
this->ctx_.Init(Args{{"device", "cpu"}});
3333
}
3434

35-
void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices, char const *c_values,
35+
void DMatrixProxy::SetCsr(char const *c_indptr, char const *c_indices, char const *c_values,
3636
bst_feature_t n_features, bool on_host) {
3737
CHECK(on_host) << "Not implemented on device.";
3838
std::shared_ptr<CSRArrayAdapter> adapter{new CSRArrayAdapter(
@@ -43,6 +43,11 @@ void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices, char
4343
this->ctx_.Init(Args{{"device", "cpu"}});
4444
}
4545

46+
#if !defined(XGBOOST_USE_CUDA)
47+
void DMatrixProxy::SetCudaArray(StringView) { common::AssertGPUSupport(); }
48+
void DMatrixProxy::SetCudaColumnar(StringView) { common::AssertGPUSupport(); }
49+
#endif // !defined(XGBOOST_USE_CUDA)
50+
4651
namespace cuda_impl {
4752
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
4853
std::shared_ptr<DMatrixProxy> proxy, float missing);

src/data/proxy_dmatrix.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
#include "proxy_dmatrix.h"
88

99
namespace xgboost::data {
10-
void DMatrixProxy::FromCudaColumnar(StringView interface_str) {
11-
auto adapter{std::make_shared<CudfAdapter>(interface_str)};
10+
void DMatrixProxy::SetCudaColumnar(StringView data) {
11+
auto adapter{std::make_shared<CudfAdapter>(data)};
1212
this->batch_ = adapter;
1313
this->Info().num_col_ = adapter->NumColumns();
1414
this->Info().num_row_ = adapter->NumRows();
@@ -21,8 +21,8 @@ void DMatrixProxy::FromCudaColumnar(StringView interface_str) {
2121
ctx_ = ctx_.MakeCUDA(adapter->Device().ordinal);
2222
}
2323

24-
void DMatrixProxy::FromCudaArray(StringView interface_str) {
25-
auto adapter(std::make_shared<CupyAdapter>(StringView{interface_str}));
24+
void DMatrixProxy::SetCudaArray(StringView data) {
25+
auto adapter(std::make_shared<CupyAdapter>(StringView{data}));
2626
this->batch_ = adapter;
2727
this->Info().num_col_ = adapter->NumColumns();
2828
this->Info().num_row_ = adapter->NumRows();

src/data/proxy_dmatrix.h

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -65,51 +65,39 @@ class DataIterProxy {
6565
};
6666

6767
/**
68-
* @brief A proxy of DMatrix used by external iterator.
68+
* @brief A proxy of DMatrix used by the external iterator.
6969
*/
7070
class DMatrixProxy : public DMatrix {
7171
MetaInfo info_;
7272
std::any batch_;
7373
Context ctx_;
7474

75-
#if defined(XGBOOST_USE_CUDA)
76-
void FromCudaColumnar(StringView interface_str);
77-
void FromCudaArray(StringView interface_str);
78-
#endif // defined(XGBOOST_USE_CUDA)
79-
8075
public:
8176
DeviceOrd Device() const { return ctx_.Device(); }
8277

83-
void SetCUDAArray(char const* c_interface) {
84-
common::AssertGPUSupport();
85-
CHECK(c_interface);
86-
#if defined(XGBOOST_USE_CUDA)
87-
StringView interface_str{c_interface};
88-
Json json_array_interface = Json::Load(interface_str);
89-
if (IsA<Array>(json_array_interface)) {
90-
this->FromCudaColumnar(interface_str);
91-
} else {
92-
this->FromCudaArray(interface_str);
93-
}
94-
#endif // defined(XGBOOST_USE_CUDA)
95-
}
96-
97-
void SetColumnarData(StringView interface_str);
98-
99-
void SetArrayData(StringView interface_str);
100-
void SetCSRData(char const* c_indptr, char const* c_indices, char const* c_values,
78+
/**
79+
* Device setters
80+
*/
81+
void SetCudaColumnar(StringView data);
82+
void SetCudaArray(StringView data);
83+
/**
84+
* Host setters
85+
*/
86+
void SetColumnar(StringView data);
87+
void SetArray(StringView data);
88+
void SetCsr(char const* c_indptr, char const* c_indices, char const* c_values,
10189
bst_feature_t n_features, bool on_host);
10290

10391
MetaInfo& Info() override { return info_; }
10492
MetaInfo const& Info() const override { return info_; }
10593
Context const* Ctx() const override { return &ctx_; }
10694

107-
bool EllpackExists() const override { return false; }
108-
bool GHistIndexExists() const override { return false; }
109-
bool SparsePageExists() const override { return false; }
95+
[[nodiscard]] bool EllpackExists() const override { return false; }
96+
[[nodiscard]] bool GHistIndexExists() const override { return false; }
97+
[[nodiscard]] bool SparsePageExists() const override { return false; }
11098

11199
template <typename Page>
112-
BatchSet<Page> NoBatch() {
100+
static BatchSet<Page> NoBatch() {
113101
LOG(FATAL) << "Proxy DMatrix cannot return data batch.";
114102
return BatchSet<Page>(BatchIterator<Page>(nullptr));
115103
}

tests/cpp/data/test_proxy_dmatrix.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
/**
2-
* Copyright 2021-2023, XGBoost contributors
2+
* Copyright 2021-2025, XGBoost contributors
33
*/
44
#include <gtest/gtest.h>
55

6-
#include "../../../src/data/adapter.h"
6+
#include <cstddef> // for size_t
7+
#include <vector> // for vector
8+
79
#include "../../../src/data/proxy_dmatrix.h"
810
#include "../helpers.h"
11+
#include "xgboost/host_device_vector.h" // for HostDeviceVector
912

1013
namespace xgboost::data {
1114
TEST(ProxyDMatrix, HostData) {
1215
DMatrixProxy proxy;
13-
size_t constexpr kRows = 100, kCols = 10;
16+
std::size_t constexpr kRows = 100, kCols = 10;
1417
std::vector<HostDeviceVector<float>> label_storage(1);
1518

1619
HostDeviceVector<float> storage;
1720
auto data =
1821
RandomDataGenerator(kRows, kCols, 0.5).Device(FstCU()).GenerateArrayInterface(&storage);
1922

20-
proxy.SetArrayData(data.c_str());
23+
proxy.SetArray(data.c_str());
2124

2225
auto n_samples = HostAdapterDispatch(&proxy, [](auto const &value) { return value.Size(); });
2326
ASSERT_EQ(n_samples, kRows);

tests/cpp/data/test_proxy_dmatrix.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
/**
2-
* Copyright 2020-2023 XGBoost contributors
2+
* Copyright 2020-2025, XGBoost contributors
33
*/
44
#include <gtest/gtest.h>
55
#include <xgboost/host_device_vector.h>
66

7-
#include <any> // for any_cast
8-
#include <memory>
7+
#include <any> // for any_cast
8+
#include <memory> // for shared_ptr
9+
#include <vector> // for vector
910

1011
#include "../../../src/data/device_adapter.cuh"
1112
#include "../../../src/data/proxy_dmatrix.h"
1213
#include "../helpers.h"
14+
#include "xgboost/host_device_vector.h" // for HostDeviceVector
1315

1416
namespace xgboost::data {
1517
TEST(ProxyDMatrix, DeviceData) {
@@ -23,7 +25,7 @@ TEST(ProxyDMatrix, DeviceData) {
2325
.GenerateColumnarArrayInterface(&label_storage);
2426

2527
DMatrixProxy proxy;
26-
proxy.SetCUDAArray(data.c_str());
28+
proxy.SetCudaArray(data.c_str());
2729
proxy.SetInfo("label", labels.c_str());
2830

2931
ASSERT_EQ(proxy.Adapter().type(), typeid(std::shared_ptr<CupyAdapter>));
@@ -35,7 +37,7 @@ TEST(ProxyDMatrix, DeviceData) {
3537
data = RandomDataGenerator(kRows, kCols, 0)
3638
.Device(FstCU())
3739
.GenerateColumnarArrayInterface(&columnar_storage);
38-
proxy.SetCUDAArray(data.c_str());
40+
proxy.SetCudaColumnar(data.c_str());
3941
ASSERT_EQ(proxy.Adapter().type(), typeid(std::shared_ptr<CudfAdapter>));
4042
ASSERT_EQ(std::any_cast<std::shared_ptr<CudfAdapter>>(proxy.Adapter())->NumRows(), kRows);
4143
ASSERT_EQ(std::any_cast<std::shared_ptr<CudfAdapter>>(proxy.Adapter())->NumColumns(), kCols);

0 commit comments

Comments
 (0)