Skip to content

Commit b05f309

Browse files
committed
[SPARK-32140][ML][PYSPARK] Add training summary to FMClassificationModel
### What changes were proposed in this pull request? Add training summary for FMClassificationModel... ### Why are the changes needed? so that user can get the training process status, such as loss value of each iteration and total iteration number. ### Does this PR introduce _any_ user-facing change? Yes FMClassificationModel.summary FMClassificationModel.evaluate ### How was this patch tested? new tests Closes apache#28960 from huaxingao/fm_summary. Authored-by: Huaxin Gao <[email protected]> Signed-off-by: Huaxin Gao <[email protected]>
1 parent cf22d94 commit b05f309

File tree

7 files changed

+257
-32
lines changed

7 files changed

+257
-32
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala

+96-4
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
3030
import org.apache.spark.mllib.linalg.{Vector => OldVector}
3131
import org.apache.spark.mllib.linalg.VectorImplicits._
3232
import org.apache.spark.rdd.RDD
33-
import org.apache.spark.sql.{Dataset, Row}
33+
import org.apache.spark.sql._
3434
import org.apache.spark.storage.StorageLevel
3535

3636
/**
@@ -212,14 +212,34 @@ class FMClassifier @Since("3.0.0") (
212212

213213
if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK)
214214

215-
val coefficients = trainImpl(data, numFeatures, LogisticLoss)
215+
val (coefficients, objectiveHistory) = trainImpl(data, numFeatures, LogisticLoss)
216216

217217
val (intercept, linear, factors) = splitCoefficients(
218218
coefficients, numFeatures, $(factorSize), $(fitIntercept), $(fitLinear))
219219

220220
if (handlePersistence) data.unpersist()
221221

222-
copyValues(new FMClassificationModel(uid, intercept, linear, factors))
222+
createModel(dataset, intercept, linear, factors, objectiveHistory)
223+
}
224+
225+
private def createModel(
226+
dataset: Dataset[_],
227+
intercept: Double,
228+
linear: Vector,
229+
factors: Matrix,
230+
objectiveHistory: Array[Double]): FMClassificationModel = {
231+
val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors))
232+
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
233+
234+
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
235+
val summary = new FMClassificationTrainingSummaryImpl(
236+
summaryModel.transform(dataset),
237+
probabilityColName,
238+
predictionColName,
239+
$(labelCol),
240+
weightColName,
241+
objectiveHistory)
242+
model.setSummary(Some(summary))
223243
}
224244

225245
@Since("3.0.0")
@@ -243,14 +263,36 @@ class FMClassificationModel private[classification] (
243263
@Since("3.0.0") val linear: Vector,
244264
@Since("3.0.0") val factors: Matrix)
245265
extends ProbabilisticClassificationModel[Vector, FMClassificationModel]
246-
with FMClassifierParams with MLWritable {
266+
with FMClassifierParams with MLWritable
267+
with HasTrainingSummary[FMClassificationTrainingSummary]{
247268

248269
@Since("3.0.0")
249270
override val numClasses: Int = 2
250271

251272
@Since("3.0.0")
252273
override val numFeatures: Int = linear.size
253274

275+
/**
276+
* Gets summary of model on training set. An exception is thrown
277+
* if `hasSummary` is false.
278+
*/
279+
@Since("3.1.0")
280+
override def summary: FMClassificationTrainingSummary = super.summary
281+
282+
/**
283+
* Evaluates the model on a test dataset.
284+
*
285+
* @param dataset Test dataset to evaluate model on.
286+
*/
287+
@Since("3.1.0")
288+
def evaluate(dataset: Dataset[_]): FMClassificationSummary = {
289+
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
290+
// Handle possible missing or invalid probability or prediction columns
291+
val (summaryModel, probability, predictionColName) = findSummaryModel()
292+
new FMClassificationSummaryImpl(summaryModel.transform(dataset),
293+
probability, predictionColName, $(labelCol), weightColName)
294+
}
295+
254296
@Since("3.0.0")
255297
override def predictRaw(features: Vector): Vector = {
256298
val rawPrediction = getRawPrediction(features, intercept, linear, factors)
@@ -328,3 +370,53 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
328370
}
329371
}
330372
}
373+
374+
/**
375+
* Abstraction for FMClassifier results for a given model.
376+
*/
377+
sealed trait FMClassificationSummary extends BinaryClassificationSummary
378+
379+
/**
380+
* Abstraction for FMClassifier training results.
381+
*/
382+
sealed trait FMClassificationTrainingSummary extends FMClassificationSummary with TrainingSummary
383+
384+
/**
385+
* FMClassifier results for a given model.
386+
*
387+
* @param predictions dataframe output by the model's `transform` method.
388+
* @param scoreCol field in "predictions" which gives the probability of each instance.
389+
* @param predictionCol field in "predictions" which gives the prediction for a data instance as a
390+
* double.
391+
* @param labelCol field in "predictions" which gives the true label of each instance.
392+
* @param weightCol field in "predictions" which gives the weight of each instance.
393+
*/
394+
private class FMClassificationSummaryImpl(
395+
@transient override val predictions: DataFrame,
396+
override val scoreCol: String,
397+
override val predictionCol: String,
398+
override val labelCol: String,
399+
override val weightCol: String)
400+
extends FMClassificationSummary
401+
402+
/**
403+
* FMClassifier training results.
404+
*
405+
* @param predictions dataframe output by the model's `transform` method.
406+
* @param scoreCol field in "predictions" which gives the probability of each instance.
407+
* @param predictionCol field in "predictions" which gives the prediction for a data instance as a
408+
* double.
409+
* @param labelCol field in "predictions" which gives the true label of each instance.
410+
* @param weightCol field in "predictions" which gives the weight of each instance.
411+
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
412+
*/
413+
private class FMClassificationTrainingSummaryImpl(
414+
predictions: DataFrame,
415+
scoreCol: String,
416+
predictionCol: String,
417+
labelCol: String,
418+
weightCol: String,
419+
override val objectiveHistory: Array[Double])
420+
extends FMClassificationSummaryImpl(
421+
predictions, scoreCol, predictionCol, labelCol, weightCol)
422+
with FMClassificationTrainingSummary

mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala

+5-5
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ import org.apache.spark.storage.StorageLevel
4747
*/
4848
private[ml] trait FactorizationMachinesParams extends PredictorParams
4949
with HasMaxIter with HasStepSize with HasTol with HasSolver with HasSeed
50-
with HasFitIntercept with HasRegParam {
50+
with HasFitIntercept with HasRegParam with HasWeightCol {
5151

5252
/**
5353
* Param for dimensionality of the factors (&gt;= 0)
@@ -134,7 +134,7 @@ private[ml] trait FactorizationMachines extends FactorizationMachinesParams {
134134
data: RDD[(Double, OldVector)],
135135
numFeatures: Int,
136136
loss: String
137-
): Vector = {
137+
): (Vector, Array[Double]) = {
138138

139139
// initialize coefficients
140140
val initialCoefficients = initCoefficients(numFeatures)
@@ -151,8 +151,8 @@ private[ml] trait FactorizationMachines extends FactorizationMachinesParams {
151151
.setRegParam($(regParam))
152152
.setMiniBatchFraction($(miniBatchFraction))
153153
.setConvergenceTol($(tol))
154-
val coefficients = optimizer.optimize(data, initialCoefficients)
155-
coefficients.asML
154+
val (coefficients, lossHistory) = optimizer.optimizeWithLossReturned(data, initialCoefficients)
155+
(coefficients.asML, lossHistory)
156156
}
157157
}
158158

@@ -421,7 +421,7 @@ class FMRegressor @Since("3.0.0") (
421421

422422
if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK)
423423

424-
val coefficients = trainImpl(data, numFeatures, SquaredError)
424+
val (coefficients, _) = trainImpl(data, numFeatures, SquaredError)
425425

426426
val (intercept, linear, factors) = splitCoefficients(
427427
coefficients, numFeatures, $(factorSize), $(fitIntercept), $(fitLinear))

mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

+29-16
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,20 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va
129129
* @return solution vector
130130
*/
131131
def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
132-
val (weights, _) = GradientDescent.runMiniBatchSGD(
132+
val (weights, _) = optimizeWithLossReturned(data, initialWeights)
133+
weights
134+
}
135+
136+
/**
137+
* Runs gradient descent on the given training data.
138+
* @param data training data
139+
* @param initialWeights initial weights
140+
* @return solution vector and loss value in an array
141+
*/
142+
def optimizeWithLossReturned(
143+
data: RDD[(Double, Vector)],
144+
initialWeights: Vector): (Vector, Array[Double]) = {
145+
GradientDescent.runMiniBatchSGD(
133146
data,
134147
gradient,
135148
updater,
@@ -139,7 +152,6 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va
139152
miniBatchFraction,
140153
initialWeights,
141154
convergenceTol)
142-
weights
143155
}
144156

145157
}
@@ -195,7 +207,7 @@ object GradientDescent extends Logging {
195207
s"numIterations=$numIterations and miniBatchFraction=$miniBatchFraction")
196208
}
197209

198-
val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
210+
val stochasticLossHistory = new ArrayBuffer[Double](numIterations + 1)
199211
// Record previous weight and current one to calculate solution vector difference
200212

201213
var previousWeights: Option[Vector] = None
@@ -226,7 +238,7 @@ object GradientDescent extends Logging {
226238

227239
var converged = false // indicates whether converged based on convergenceTol
228240
var i = 1
229-
while (!converged && i <= numIterations) {
241+
while (!converged && (i <= numIterations + 1)) {
230242
val bcWeights = data.context.broadcast(weights)
231243
// Sample a subset (fraction miniBatchFraction) of the total data
232244
// compute and sum up the subgradients on this subset (this is one map-reduce)
@@ -249,17 +261,19 @@ object GradientDescent extends Logging {
249261
* and regVal is the regularization value computed in the previous iteration as well.
250262
*/
251263
stochasticLossHistory += lossSum / miniBatchSize + regVal
252-
val update = updater.compute(
253-
weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble),
254-
stepSize, i, regParam)
255-
weights = update._1
256-
regVal = update._2
257-
258-
previousWeights = currentWeights
259-
currentWeights = Some(weights)
260-
if (previousWeights != None && currentWeights != None) {
261-
converged = isConverged(previousWeights.get,
262-
currentWeights.get, convergenceTol)
264+
if (i != (numIterations + 1)) {
265+
val update = updater.compute(
266+
weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble),
267+
stepSize, i, regParam)
268+
weights = update._1
269+
regVal = update._2
270+
271+
previousWeights = currentWeights
272+
currentWeights = Some(weights)
273+
if (previousWeights != None && currentWeights != None) {
274+
converged = isConverged(previousWeights.get,
275+
currentWeights.get, convergenceTol)
276+
}
263277
}
264278
} else {
265279
logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
@@ -271,7 +285,6 @@ object GradientDescent extends Logging {
271285
stochasticLossHistory.takeRight(10).mkString(", ")))
272286

273287
(weights, stochasticLossHistory.toArray)
274-
275288
}
276289

277290
/**

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

+8-3
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,14 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
136136
}
137137

138138
override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
139-
val (weights, _) = LBFGS.runLBFGS(
139+
val (weights, _) = optimizeWithLossReturned(data, initialWeights)
140+
weights
141+
}
142+
143+
def optimizeWithLossReturned(
144+
data: RDD[(Double, Vector)],
145+
initialWeights: Vector): (Vector, Array[Double]) = {
146+
LBFGS.runLBFGS(
140147
data,
141148
gradient,
142149
updater,
@@ -145,9 +152,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
145152
maxNumIterations,
146153
regParam,
147154
initialWeights)
148-
weights
149155
}
150-
151156
}
152157

153158
/**

mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala

+26
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,32 @@ class FMClassifierSuite extends MLTest with DefaultReadWriteTest {
194194
testPredictionModelSinglePrediction(fmModel, smallBinaryDataset)
195195
}
196196

197+
test("summary and training summary") {
198+
val fm = new FMClassifier()
199+
val model = fm.setMaxIter(5).fit(smallBinaryDataset)
200+
201+
val summary = model.evaluate(smallBinaryDataset)
202+
203+
assert(model.summary.accuracy === summary.accuracy)
204+
assert(model.summary.weightedPrecision === summary.weightedPrecision)
205+
assert(model.summary.weightedRecall === summary.weightedRecall)
206+
assert(model.summary.pr.collect() === summary.pr.collect())
207+
assert(model.summary.roc.collect() === summary.roc.collect())
208+
assert(model.summary.areaUnderROC === summary.areaUnderROC)
209+
}
210+
211+
test("FMClassifier training summary totalIterations") {
212+
Seq(1, 5, 10, 20, 100).foreach { maxIter =>
213+
val trainer = new FMClassifier().setMaxIter(maxIter)
214+
val model = trainer.fit(smallBinaryDataset)
215+
if (maxIter == 1) {
216+
assert(model.summary.totalIterations === maxIter)
217+
} else {
218+
assert(model.summary.totalIterations <= maxIter)
219+
}
220+
}
221+
}
222+
197223
test("read/write") {
198224
def checkModelData(
199225
model: FMClassificationModel,

python/pyspark/ml/classification.py

+46-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@
5252
'NaiveBayes', 'NaiveBayesModel',
5353
'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel',
5454
'OneVsRest', 'OneVsRestModel',
55-
'FMClassifier', 'FMClassificationModel']
55+
'FMClassifier', 'FMClassificationModel', 'FMClassificationSummary',
56+
'FMClassificationTrainingSummary']
5657

5758

5859
class _ClassifierParams(HasRawPredictionCol, _PredictorParams):
@@ -3226,7 +3227,7 @@ def setRegParam(self, value):
32263227

32273228

32283229
class FMClassificationModel(_JavaProbabilisticClassificationModel, _FactorizationMachinesParams,
3229-
JavaMLWritable, JavaMLReadable):
3230+
JavaMLWritable, JavaMLReadable, HasTrainingSummary):
32303231
"""
32313232
Model fitted by :class:`FMClassifier`.
32323233
@@ -3257,6 +3258,49 @@ def factors(self):
32573258
"""
32583259
return self._call_java("factors")
32593260

3261+
@since("3.1.0")
3262+
def summary(self):
3263+
"""
3264+
Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model
3265+
trained on the training set. An exception is thrown if `trainingSummary is None`.
3266+
"""
3267+
if self.hasSummary:
3268+
return FMClassificationTrainingSummary(super(FMClassificationModel, self).summary)
3269+
else:
3270+
raise RuntimeError("No training summary available for this %s" %
3271+
self.__class__.__name__)
3272+
3273+
@since("3.1.0")
3274+
def evaluate(self, dataset):
3275+
"""
3276+
Evaluates the model on a test dataset.
3277+
3278+
:param dataset:
3279+
Test dataset to evaluate model on, where dataset is an
3280+
instance of :py:class:`pyspark.sql.DataFrame`
3281+
"""
3282+
if not isinstance(dataset, DataFrame):
3283+
raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
3284+
java_fm_summary = self._call_java("evaluate", dataset)
3285+
return FMClassificationSummary(java_fm_summary)
3286+
3287+
3288+
class FMClassificationSummary(_BinaryClassificationSummary):
3289+
"""
3290+
Abstraction for FMClassifier Results for a given model.
3291+
.. versionadded:: 3.1.0
3292+
"""
3293+
pass
3294+
3295+
3296+
@inherit_doc
3297+
class FMClassificationTrainingSummary(FMClassificationSummary, _TrainingSummary):
3298+
"""
3299+
Abstraction for FMClassifier Training results.
3300+
.. versionadded:: 3.1.0
3301+
"""
3302+
pass
3303+
32603304

32613305
if __name__ == "__main__":
32623306
import doctest

0 commit comments

Comments
 (0)