     HasFeaturesCols,
     HasQueryIdCol,
 )
+from .summary import XGBoostTrainingSummary
 from .utils import (
     CommunicatorContext,
     _get_default_params_from_func,
@@ -704,8 +705,10 @@ def _pyspark_model_cls(cls) -> Type["_SparkXGBModel"]:
         """
         raise NotImplementedError()
 
-    def _create_pyspark_model(self, xgb_model: XGBModel) -> "_SparkXGBModel":
-        return self._pyspark_model_cls()(xgb_model)
+    def _create_pyspark_model(
+        self, xgb_model: XGBModel, training_summary: XGBoostTrainingSummary
+    ) -> "_SparkXGBModel":
+        return self._pyspark_model_cls()(xgb_model, training_summary)
 
     def _convert_to_sklearn_model(self, booster: bytearray, config: str) -> XGBModel:
         xgb_sklearn_params = self._gen_xgb_params_dict(
@@ -1148,7 +1151,7 @@ def _train_booster(
                 if dvalid is not None:
                     dval = [(dtrain, "training"), (dvalid, "validation")]
                 else:
-                    dval = None
+                    dval = [(dtrain, "training")]
                 booster = worker_train(
                     params=booster_params,
                     dtrain=dtrain,
@@ -1159,6 +1162,7 @@
             context.barrier()
 
             if context.partitionId() == 0:
+                yield pd.DataFrame({"data": [json.dumps(dict(evals_result))]})
                 config = booster.save_config()
                 yield pd.DataFrame({"data": [config]})
                 booster_json = booster.save_raw("json").decode("utf-8")
@@ -1167,7 +1171,7 @@ def _train_booster(
                 booster_chunk = booster_json[offset : offset + _MODEL_CHUNK_SIZE]
                 yield pd.DataFrame({"data": [booster_chunk]})
 
-        def _run_job() -> Tuple[str, str]:
+        def _run_job() -> Tuple[str, str, str]:
             rdd = (
                 dataset.mapInPandas(
                     _train_booster,  # type: ignore
@@ -1179,7 +1183,7 @@ def _run_job() -> Tuple[str, str]:
             rdd_with_resource = self._try_stage_level_scheduling(rdd)
             ret = rdd_with_resource.collect()
             data = [v[0] for v in ret]
-            return data[0], "".join(data[1:])
+            return data[0], data[1], "".join(data[2:])
 
         get_logger(_LOG_TAG).info(
             "Running xgboost-%s on %s workers with"
@@ -1192,13 +1196,14 @@ def _run_job() -> Tuple[str, str]:
             train_call_kwargs_params,
             dmatrix_kwargs,
         )
-        (config, booster) = _run_job()
+        (evals_result, config, booster) = _run_job()
         get_logger(_LOG_TAG).info("Finished xgboost training!")
 
         result_xgb_model = self._convert_to_sklearn_model(
             bytearray(booster, "utf-8"), config
         )
-        spark_model = self._create_pyspark_model(result_xgb_model)
+        training_summary = XGBoostTrainingSummary.from_metrics(json.loads(evals_result))
+        spark_model = self._create_pyspark_model(result_xgb_model, training_summary)
         # According to pyspark ML convention, the model uid should be the same
         # with estimator uid.
         spark_model._resetUid(self.uid)
@@ -1219,9 +1224,14 @@ def read(cls) -> "SparkXGBReader":
 
 
 class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
-    def __init__(self, xgb_sklearn_model: Optional[XGBModel] = None) -> None:
+    def __init__(
+        self,
+        xgb_sklearn_model: Optional[XGBModel] = None,
+        training_summary: Optional[XGBoostTrainingSummary] = None,
+    ) -> None:
         super().__init__()
         self._xgb_sklearn_model = xgb_sklearn_model
+        self.training_summary = training_summary
 
     @classmethod
     def _xgb_cls(cls) -> Type[XGBModel]:
0 commit comments