@@ -30,7 +30,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
30
30
import org .apache .spark .mllib .linalg .{Vector => OldVector }
31
31
import org .apache .spark .mllib .linalg .VectorImplicits ._
32
32
import org .apache .spark .rdd .RDD
33
- import org .apache .spark .sql .{ Dataset , Row }
33
+ import org .apache .spark .sql ._
34
34
import org .apache .spark .storage .StorageLevel
35
35
36
36
/**
@@ -212,14 +212,34 @@ class FMClassifier @Since("3.0.0") (
212
212
213
213
if (handlePersistence) data.persist(StorageLevel .MEMORY_AND_DISK )
214
214
215
- val coefficients = trainImpl(data, numFeatures, LogisticLoss )
215
+ val ( coefficients, objectiveHistory) = trainImpl(data, numFeatures, LogisticLoss )
216
216
217
217
val (intercept, linear, factors) = splitCoefficients(
218
218
coefficients, numFeatures, $(factorSize), $(fitIntercept), $(fitLinear))
219
219
220
220
if (handlePersistence) data.unpersist()
221
221
222
- copyValues(new FMClassificationModel (uid, intercept, linear, factors))
222
+ createModel(dataset, intercept, linear, factors, objectiveHistory)
223
+ }
224
+
225
/**
 * Wraps the fitted parameters in an [[FMClassificationModel]] and attaches a training
 * summary computed by transforming `dataset` with the summary model.
 *
 * @param dataset dataset passed through the summary model's `transform`.
 * @param intercept intercept handed to the model constructor.
 * @param linear linear term handed to the model constructor.
 * @param factors factor matrix handed to the model constructor.
 * @param objectiveHistory objective values forwarded to the training summary.
 * @return the result of calling `setSummary` on the newly built model.
 */
private def createModel(
    dataset: Dataset[_],
    intercept: Double,
    linear: Vector,
    factors: Matrix,
    objectiveHistory: Array[Double]): FMClassificationModel = {
  val fitted = copyValues(new FMClassificationModel(uid, intercept, linear, factors))

  // Use a fixed fallback column name when no weight column is defined on this estimator.
  val weights = if (isDefined(weightCol)) $(weightCol) else " weightCol"

  val (scorer, probabilityColName, predictionColName) = fitted.findSummaryModel()
  val scored = scorer.transform(dataset)
  val trainingSummary = new FMClassificationTrainingSummaryImpl(
    scored, probabilityColName, predictionColName, $(labelCol), weights, objectiveHistory)
  fitted.setSummary(Some(trainingSummary))
}
224
244
225
245
@ Since (" 3.0.0" )
@@ -243,14 +263,36 @@ class FMClassificationModel private[classification] (
243
263
@ Since (" 3.0.0" ) val linear : Vector ,
244
264
@ Since (" 3.0.0" ) val factors : Matrix )
245
265
extends ProbabilisticClassificationModel [Vector , FMClassificationModel ]
246
- with FMClassifierParams with MLWritable {
266
+ with FMClassifierParams with MLWritable
267
+ with HasTrainingSummary [FMClassificationTrainingSummary ]{
247
268
248
269
@ Since (" 3.0.0" )
249
270
override val numClasses : Int = 2
250
271
251
272
@ Since (" 3.0.0" )
252
273
override val numFeatures : Int = linear.size
253
274
275
/**
 * Training summary of this model, delegated to `super.summary`. An exception is thrown
 * if `hasSummary` is false.
 */
@Since(" 3.1.0")
override def summary: FMClassificationTrainingSummary = {
  super.summary
}
281
+
282
/**
 * Evaluates the model on a test dataset.
 *
 * @param dataset Test dataset to evaluate model on.
 * @return a summary built from the summary model's output on `dataset`.
 */
@Since(" 3.1.0")
def evaluate(dataset: Dataset[_]): FMClassificationSummary = {
  // Use a fixed fallback column name when no weight column is defined on this model.
  val weights = if (isDefined(weightCol)) $(weightCol) else " weightCol"
  // Handle possible missing or invalid probability or prediction columns
  val (scorer, probabilityColName, predictionColName) = findSummaryModel()
  val scored = scorer.transform(dataset)
  new FMClassificationSummaryImpl(
    scored, probabilityColName, predictionColName, $(labelCol), weights)
}
295
+
254
296
@ Since (" 3.0.0" )
255
297
override def predictRaw (features : Vector ): Vector = {
256
298
val rawPrediction = getRawPrediction(features, intercept, linear, factors)
@@ -328,3 +370,53 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
328
370
}
329
371
}
330
372
}
373
+
374
/**
 * Summary of FMClassifier results for a given model. Declares no members of its own
 * beyond those inherited from [[BinaryClassificationSummary]].
 */
sealed trait FMClassificationSummary extends BinaryClassificationSummary
378
+
379
/**
 * Summary of FMClassifier training results: an [[FMClassificationSummary]] combined with
 * [[TrainingSummary]].
 */
sealed trait FMClassificationTrainingSummary
  extends FMClassificationSummary
  with TrainingSummary
383
+
384
/**
 * Concrete [[FMClassificationSummary]] holding a predictions dataframe and the names of
 * the columns it should be read with.
 *
 * @param predictions dataframe output by the model's `transform` method.
 * @param scoreCol field in "predictions" which gives the probability of each instance.
 * @param predictionCol field in "predictions" which gives the prediction for a data
 *                      instance as a double.
 * @param labelCol field in "predictions" which gives the true label of each instance.
 * @param weightCol field in "predictions" which gives the weight of each instance.
 */
private class FMClassificationSummaryImpl(
    @transient override val predictions: DataFrame,
    override val scoreCol: String,
    override val predictionCol: String,
    override val labelCol: String,
    override val weightCol: String)
  extends FMClassificationSummary
401
+
402
/**
 * Concrete [[FMClassificationTrainingSummary]]: an [[FMClassificationSummaryImpl]] that
 * additionally carries the objective history.
 *
 * @param predictions dataframe output by the model's `transform` method.
 * @param scoreCol field in "predictions" which gives the probability of each instance.
 * @param predictionCol field in "predictions" which gives the prediction for a data
 *                      instance as a double.
 * @param labelCol field in "predictions" which gives the true label of each instance.
 * @param weightCol field in "predictions" which gives the weight of each instance.
 * @param objectiveHistory objective function (scaled loss + regularization) at each
 *                         iteration.
 */
private class FMClassificationTrainingSummaryImpl(
    predictions: DataFrame,
    scoreCol: String,
    predictionCol: String,
    labelCol: String,
    weightCol: String,
    override val objectiveHistory: Array[Double])
  extends FMClassificationSummaryImpl(
    predictions,
    scoreCol,
    predictionCol,
    labelCol,
    weightCol)
  with FMClassificationTrainingSummary
0 commit comments