@@ -422,52 +422,45 @@ XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle *out) {
422422 API_END ();
423423}
424424
425- XGB_DLL int XGProxyDMatrixSetDataCudaArrayInterface (DMatrixHandle handle,
426- char const *c_interface_str) {
427- API_BEGIN ();
428- CHECK_HANDLE ();
429- xgboost_CHECK_C_ARG_PTR (c_interface_str);
425+ namespace {
426+ [[nodiscard]] xgboost::data::DMatrixProxy *GetDMatrixProxy (DMatrixHandle handle) {
430427 auto p_m = static_cast <std::shared_ptr<xgboost::DMatrix> *>(handle);
431428 CHECK (p_m);
432429 auto m = static_cast <xgboost::data::DMatrixProxy *>(p_m->get ());
433430 CHECK (m) << " Current DMatrix type does not support set data." ;
434- m->SetCUDAArray (c_interface_str);
431+ return m;
432+ }
433+ } // namespace
434+
435+ XGB_DLL int XGProxyDMatrixSetDataCudaArrayInterface (DMatrixHandle handle, char const *data) {
436+ API_BEGIN ();
437+ CHECK_HANDLE ();
438+ xgboost_CHECK_C_ARG_PTR (data);
439+ GetDMatrixProxy (handle)->SetCudaArray (data);
435440 API_END ();
436441}
437442
438- XGB_DLL int XGProxyDMatrixSetDataCudaColumnar (DMatrixHandle handle, char const *c_interface_str ) {
443+ XGB_DLL int XGProxyDMatrixSetDataCudaColumnar (DMatrixHandle handle, char const *data ) {
439444 API_BEGIN ();
440445 CHECK_HANDLE ();
441- xgboost_CHECK_C_ARG_PTR (c_interface_str);
442- auto p_m = static_cast <std::shared_ptr<xgboost::DMatrix> *>(handle);
443- CHECK (p_m);
444- auto m = static_cast <xgboost::data::DMatrixProxy *>(p_m->get ());
445- CHECK (m) << " Current DMatrix type does not support set data." ;
446- m->SetCUDAArray (c_interface_str);
446+ xgboost_CHECK_C_ARG_PTR (data);
447+ GetDMatrixProxy (handle)->SetCudaColumnar (data);
447448 API_END ();
448449}
449450
450- XGB_DLL int XGProxyDMatrixSetDataColumnar (DMatrixHandle handle, char const *c_interface_str ) {
451+ XGB_DLL int XGProxyDMatrixSetDataColumnar (DMatrixHandle handle, char const *data ) {
451452 API_BEGIN ();
452453 CHECK_HANDLE ();
453- xgboost_CHECK_C_ARG_PTR (c_interface_str);
454- auto p_m = static_cast <std::shared_ptr<xgboost::DMatrix> *>(handle);
455- CHECK (p_m);
456- auto m = static_cast <xgboost::data::DMatrixProxy *>(p_m->get ());
457- CHECK (m) << " Current DMatrix type does not support set data." ;
458- m->SetColumnarData (c_interface_str);
454+ xgboost_CHECK_C_ARG_PTR (data);
455+ GetDMatrixProxy (handle)->SetColumnar (data);
459456 API_END ();
460457}
461458
462- XGB_DLL int XGProxyDMatrixSetDataDense (DMatrixHandle handle, char const *c_interface_str ) {
459+ XGB_DLL int XGProxyDMatrixSetDataDense (DMatrixHandle handle, char const *data ) {
463460 API_BEGIN ();
464461 CHECK_HANDLE ();
465- xgboost_CHECK_C_ARG_PTR (c_interface_str);
466- auto p_m = static_cast <std::shared_ptr<xgboost::DMatrix> *>(handle);
467- CHECK (p_m);
468- auto m = static_cast <xgboost::data::DMatrixProxy *>(p_m->get ());
469- CHECK (m) << " Current DMatrix type does not support set data." ;
470- m->SetArrayData (c_interface_str);
462+ xgboost_CHECK_C_ARG_PTR (data);
463+ GetDMatrixProxy (handle)->SetArray (data);
471464 API_END ();
472465}
473466
@@ -478,11 +471,7 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr, c
478471 xgboost_CHECK_C_ARG_PTR (indptr);
479472 xgboost_CHECK_C_ARG_PTR (indices);
480473 xgboost_CHECK_C_ARG_PTR (data);
481- auto p_m = static_cast <std::shared_ptr<xgboost::DMatrix> *>(handle);
482- CHECK (p_m);
483- auto m = static_cast <xgboost::data::DMatrixProxy *>(p_m->get ());
484- CHECK (m) << " Current DMatrix type does not support set data." ;
485- m->SetCSRData (indptr, indices, data, ncol, true );
474+ GetDMatrixProxy (handle)->SetCsr (indptr, indices, data, ncol, true );
486475 API_END ();
487476}
488477
@@ -1402,7 +1391,7 @@ void InplacePredictImpl(std::shared_ptr<DMatrix> p_m, char const *c_json_config,
14021391 *out_shape = dmlc::BeginPtr (shape);
14031392}
14041393
1405- XGB_DLL int XGBoosterPredictFromDense (BoosterHandle handle, char const *array_interface ,
1394+ XGB_DLL int XGBoosterPredictFromDense (BoosterHandle handle, char const *data ,
14061395 char const *c_json_config, DMatrixHandle m,
14071396 xgboost::bst_ulong const **out_shape,
14081397 xgboost::bst_ulong *out_dim, const float **out_result) {
@@ -1416,8 +1405,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *array_in
14161405 }
14171406 auto proxy = dynamic_cast <data::DMatrixProxy *>(p_m.get ());
14181407 CHECK (proxy) << " Invalid input type for inplace predict." ;
1419- xgboost_CHECK_C_ARG_PTR (array_interface );
1420- proxy->SetArrayData (array_interface );
1408+ xgboost_CHECK_C_ARG_PTR (data );
1409+ proxy->SetArray (data );
14211410 auto *learner = static_cast <xgboost::Learner *>(handle);
14221411 InplacePredictImpl (p_m, c_json_config, learner, out_shape, out_dim, out_result);
14231412 API_END ();
@@ -1438,7 +1427,7 @@ XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array
14381427 auto proxy = dynamic_cast <data::DMatrixProxy *>(p_m.get ());
14391428 CHECK (proxy) << " Invalid input type for inplace predict." ;
14401429 xgboost_CHECK_C_ARG_PTR (array_interface);
1441- proxy->SetColumnarData (array_interface);
1430+ proxy->SetColumnar (array_interface);
14421431 auto *learner = static_cast <xgboost::Learner *>(handle);
14431432 InplacePredictImpl (p_m, c_json_config, learner, out_shape, out_dim, out_result);
14441433 API_END ();
@@ -1460,7 +1449,7 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, ch
14601449 auto proxy = dynamic_cast <data::DMatrixProxy *>(p_m.get ());
14611450 CHECK (proxy) << " Invalid input type for inplace predict." ;
14621451 xgboost_CHECK_C_ARG_PTR (indptr);
1463- proxy->SetCSRData (indptr, indices, data, cols, true );
1452+ proxy->SetCsr (indptr, indices, data, cols, true );
14641453 auto *learner = static_cast <xgboost::Learner *>(handle);
14651454 InplacePredictImpl (p_m, c_json_config, learner, out_shape, out_dim, out_result);
14661455 API_END ();
0 commit comments