Skip to content

Commit 8a4378c

Browse files
committed
[SPARK-29686][ML] LinearSVC should persist instances if needed
### What changes were proposed in this pull request? persist the input if needed ### Why are the changes needed? training with non-cached dataset will hurt performance ### Does this PR introduce any user-facing change? No ### How was this patch tested? existing tests Closes apache#26344 from zhengruifeng/linear_svc_cache. Authored-by: zhengruifeng <[email protected]> Signed-off-by: zhengruifeng <[email protected]>
1 parent ae7450d commit 8a4378c

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
3737
import org.apache.spark.mllib.linalg.VectorImplicits._
3838
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
3939
import org.apache.spark.sql.{Dataset, Row}
40+
import org.apache.spark.storage.StorageLevel
4041

4142
/** Params for linear SVM Classifier. */
4243
private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
@@ -159,7 +160,10 @@ class LinearSVC @Since("2.2.0") (
159160
override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
160161

161162
override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr =>
163+
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
164+
162165
val instances = extractInstances(dataset)
166+
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
163167

164168
instr.logPipelineStage(this)
165169
instr.logDataset(dataset)
@@ -268,6 +272,8 @@ class LinearSVC @Since("2.2.0") (
268272
(Vectors.dense(coefficientArray), intercept, scaledObjectiveHistory.result())
269273
}
270274

275+
if (handlePersistence) instances.unpersist()
276+
271277
copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector))
272278
}
273279
}

0 commit comments

Comments
 (0)