Skip to content

Commit 461d27c

Browse files
authored
[PySpark] Expose Training and Validation Metrics (#11133)
1 parent c3aa7fe commit 461d27c

File tree

3 files changed

+241
-17
lines changed

3 files changed

+241
-17
lines changed

python-package/xgboost/spark/core.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
HasFeaturesCols,
8383
HasQueryIdCol,
8484
)
85+
from .summary import XGBoostTrainingSummary
8586
from .utils import (
8687
CommunicatorContext,
8788
_get_default_params_from_func,
@@ -704,8 +705,10 @@ def _pyspark_model_cls(cls) -> Type["_SparkXGBModel"]:
704705
"""
705706
raise NotImplementedError()
706707

707-
def _create_pyspark_model(self, xgb_model: XGBModel) -> "_SparkXGBModel":
708-
return self._pyspark_model_cls()(xgb_model)
708+
def _create_pyspark_model(
709+
self, xgb_model: XGBModel, training_summary: XGBoostTrainingSummary
710+
) -> "_SparkXGBModel":
711+
return self._pyspark_model_cls()(xgb_model, training_summary)
709712

710713
def _convert_to_sklearn_model(self, booster: bytearray, config: str) -> XGBModel:
711714
xgb_sklearn_params = self._gen_xgb_params_dict(
@@ -1148,7 +1151,7 @@ def _train_booster(
11481151
if dvalid is not None:
11491152
dval = [(dtrain, "training"), (dvalid, "validation")]
11501153
else:
1151-
dval = None
1154+
dval = [(dtrain, "training")]
11521155
booster = worker_train(
11531156
params=booster_params,
11541157
dtrain=dtrain,
@@ -1159,6 +1162,7 @@ def _train_booster(
11591162
context.barrier()
11601163

11611164
if context.partitionId() == 0:
1165+
yield pd.DataFrame({"data": [json.dumps(dict(evals_result))]})
11621166
config = booster.save_config()
11631167
yield pd.DataFrame({"data": [config]})
11641168
booster_json = booster.save_raw("json").decode("utf-8")
@@ -1167,7 +1171,7 @@ def _train_booster(
11671171
booster_chunk = booster_json[offset : offset + _MODEL_CHUNK_SIZE]
11681172
yield pd.DataFrame({"data": [booster_chunk]})
11691173

1170-
def _run_job() -> Tuple[str, str]:
1174+
def _run_job() -> Tuple[str, str, str]:
11711175
rdd = (
11721176
dataset.mapInPandas(
11731177
_train_booster, # type: ignore
@@ -1179,7 +1183,7 @@ def _run_job() -> Tuple[str, str]:
11791183
rdd_with_resource = self._try_stage_level_scheduling(rdd)
11801184
ret = rdd_with_resource.collect()
11811185
data = [v[0] for v in ret]
1182-
return data[0], "".join(data[1:])
1186+
return data[0], data[1], "".join(data[2:])
11831187

11841188
get_logger(_LOG_TAG).info(
11851189
"Running xgboost-%s on %s workers with"
@@ -1192,13 +1196,14 @@ def _run_job() -> Tuple[str, str]:
11921196
train_call_kwargs_params,
11931197
dmatrix_kwargs,
11941198
)
1195-
(config, booster) = _run_job()
1199+
(evals_result, config, booster) = _run_job()
11961200
get_logger(_LOG_TAG).info("Finished xgboost training!")
11971201

11981202
result_xgb_model = self._convert_to_sklearn_model(
11991203
bytearray(booster, "utf-8"), config
12001204
)
1201-
spark_model = self._create_pyspark_model(result_xgb_model)
1205+
training_summary = XGBoostTrainingSummary.from_metrics(json.loads(evals_result))
1206+
spark_model = self._create_pyspark_model(result_xgb_model, training_summary)
12021207
# According to pyspark ML convention, the model uid should be the same
12031208
# with estimator uid.
12041209
spark_model._resetUid(self.uid)
@@ -1219,9 +1224,14 @@ def read(cls) -> "SparkXGBReader":
12191224

12201225

12211226
class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
1222-
def __init__(self, xgb_sklearn_model: Optional[XGBModel] = None) -> None:
1227+
def __init__(
1228+
self,
1229+
xgb_sklearn_model: Optional[XGBModel] = None,
1230+
training_summary: Optional[XGBoostTrainingSummary] = None,
1231+
) -> None:
12231232
super().__init__()
12241233
self._xgb_sklearn_model = xgb_sklearn_model
1234+
self.training_summary = training_summary
12251235

12261236
@classmethod
12271237
def _xgb_cls(cls) -> Type[XGBModel]:
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
"""Xgboost training summary integration submodule."""

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class XGBoostTrainingSummary:
    """Objective-metric history collected while fitting an XGBoost model.

    ``train_objective_history`` maps each metric name to its per-iteration
    values on the training set; ``validation_objective_history`` holds the
    same for the validation set (empty when no validation data was used).
    """

    train_objective_history: Dict[str, List[float]] = field(default_factory=dict)
    validation_objective_history: Dict[str, List[float]] = field(default_factory=dict)

    @staticmethod
    def from_metrics(
        metrics: Dict[str, Dict[str, List[float]]]
    ) -> "XGBoostTrainingSummary":
        """Build a summary from a nested evals-result dictionary.

        Parameters
        ----------
        metrics :
            Mapping of dataset name to a mapping of metric name to
            per-iteration values, e.g.::

                {
                    "training": {"logloss": [0.1, 0.08]},
                    "validation": {"logloss": [0.12, 0.1]},
                }

        Returns
        -------
        XGBoostTrainingSummary
            Summary holding whichever of the "training" / "validation"
            histories are present; missing keys default to empty dicts.
        """
        return XGBoostTrainingSummary(
            train_objective_history=metrics.get("training", {}),
            validation_objective_history=metrics.get("validation", {}),
        )

0 commit comments

Comments
 (0)