88#ifndef XGBOOST_LEARNER_H_
99#define XGBOOST_LEARNER_H_
1010
11- #include < dmlc/io.h> // for Serializable
12- #include < xgboost/base.h> // for bst_feature_t, bst_target_t, bst_float, Args, GradientPair, ..
13- #include < xgboost/context.h> // for Context
14- #include < xgboost/linalg.h> // for Vector, VectorView
15- #include < xgboost/metric.h> // for Metric
16- #include < xgboost/model.h> // for Configurable, Model
17- #include < xgboost/span.h> // for Span
18- #include < xgboost/task.h> // for ObjInfo
11+ #include < dmlc/io.h> // for Serializable
12+ #include < xgboost/base.h> // for bst_feature_t, bst_target_t, bst_float, Args, GradientPair, ..
13+ #include < xgboost/context.h> // for Context
14+ #include < xgboost/gradient.h> // for GradientContainer
15+ #include < xgboost/linalg.h> // for Vector, VectorView
16+ #include < xgboost/metric.h> // for Metric
17+ #include < xgboost/model.h> // for Configurable, Model
18+ #include < xgboost/span.h> // for Span
19+ #include < xgboost/task.h> // for ObjInfo
1920
20- #include < algorithm> // for max
21- #include < cstdint> // for int32_t, uint32_t, uint8_t
22- #include < map> // for map
23- #include < memory> // for shared_ptr, unique_ptr
24- #include < string> // for string
25- #include < utility> // for move
26- #include < vector> // for vector
21+ #include < algorithm> // for max
22+ #include < cstdint> // for int32_t, uint32_t, uint8_t
23+ #include < map> // for map
24+ #include < memory> // for shared_ptr, unique_ptr
25+ #include < string> // for string
26+ #include < utility> // for move
27+ #include < vector> // for vector
2728
2829namespace xgboost {
2930class FeatureMap ;
@@ -47,25 +48,24 @@ enum class PredictionType : std::uint8_t { // NOLINT
4748 kLeaf = 6
4849};
4950
50- /* !
51- * \ brief Learner class that does training and prediction.
51+ /* *
52+ * @ brief Learner class that does training and prediction.
5253 * This is the user facing module of xgboost training.
5354 * The Load/Save function corresponds to the model used in python/R.
54- * \ code
55+ * @ code
5556 *
56- * std::unique_ptr<Learner> learner(new Learner::Create(cache_mats)) ;
57- * learner. Configure(configs);
57+ * std::unique_ptr<Learner> learner{ Learner::Create(cache_mats)} ;
58+ * learner-> Configure(configs);
5859 *
5960 * for (int iter = 0; iter < max_iter; ++iter) {
6061 * learner->UpdateOneIter(iter, train_mat);
6162 * LOG(INFO) << learner->EvalOneIter(iter, data_sets, data_names);
6263 * }
6364 *
64- * \ endcode
65+ * @ endcode
6566 */
6667class Learner : public Model , public Configurable , public dmlc ::Serializable {
6768 public:
68- /* ! \brief virtual destructor */
6969 ~Learner () override ;
7070 /* !
7171 * \brief Configure Learner based on set parameters.
@@ -88,7 +88,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
8888 * @param in_gpair The input gradient statistics.
8989 */
9090 virtual void BoostOneIter (std::int32_t iter, std::shared_ptr<DMatrix> train,
91- linalg::Matrix<GradientPair> * in_gpair) = 0;
91+ GradientContainer * in_gpair) = 0;
9292 /* !
9393 * \brief evaluate the model for specific iteration using the configured metrics.
9494 * \param iter iteration number
0 commit comments