@@ -10,6 +10,7 @@ import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor, RandomF
 import org.apache.spark.mllib.classification._
 import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater, SimpleUpdater}
 import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}
 import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest}
@@ -140,46 +141,34 @@ class GLMRegressionTest(sc: SparkContext) extends GLMTests(sc) {
     val regParam = doubleOptionValue(REG_PARAM)
     val elasticNetParam = doubleOptionValue(ELASTIC_NET_PARAM)
     val numIterations = intOptionValue(NUM_ITERATIONS)
-    val optimizer = stringOptionValue(OPTIMIZER)
+    // val optimizer = stringOptionValue(OPTIMIZER) // ignore for now since it makes config hard to do
 
     // Linear Regression only supports squared loss for now.
     if (!Array("l2").contains(loss)) {
       throw new IllegalArgumentException(
         s"GLMRegressionTest run with unknown loss ($loss). Supported values: l2.")
     }
 
-    if (Array("sgd").contains(optimizer)) {
-      if (!Array("none", "l1", "l2").contains(regType)) {
-        throw new IllegalArgumentException(
-          s"GLMRegressionTest run with unknown regType ($regType) with sgd. Supported values: none, l1, l2.")
-      }
-    } else if (Array("lbfgs").contains(optimizer)) {
-      if (!Array("elastic-net").contains(regType)) {
-        throw new IllegalArgumentException(
-          s"GLMRegressionTest run with unknown regType ($regType) with lbfgs. Supported values: elastic-net.")
-      }
-    } else {
-      throw new IllegalArgumentException(
-        s"GLMRegressionTest run with unknown optimizer ($optimizer). Supported values: sgd, lbfgs.")
-    }
-
     (loss, regType) match {
       case ("l2", "none") =>
         val lr = new LinearRegressionWithSGD().setIntercept(addIntercept = true)
-        lr.optimizer.setNumIterations(numIterations).setStepSize(stepSize)
+        lr.optimizer.setNumIterations(numIterations).setStepSize(stepSize).setConvergenceTol(0.0)
         lr.run(rdd)
       case ("l2", "l1") =>
         val lasso = new LassoWithSGD().setIntercept(addIntercept = true)
         lasso.optimizer.setNumIterations(numIterations).setStepSize(stepSize).setRegParam(regParam)
+          .setConvergenceTol(0.0)
         lasso.run(rdd)
       case ("l2", "l2") =>
         val rr = new RidgeRegressionWithSGD().setIntercept(addIntercept = true)
         rr.optimizer.setNumIterations(numIterations).setStepSize(stepSize).setRegParam(regParam)
+          .setConvergenceTol(0.0)
         rr.run(rdd)
       case ("l2", "elastic-net") =>
-        println("WARNING: Linear Regression with elastic-net in ML package uses LBFGS/OWLQN for optimization" +
-          " which ignores stepSize and uses numIterations for maxIter in Spark 1.5.")
-        val rr = new LinearRegression().setElasticNetParam(elasticNetParam).setRegParam(regParam).setMaxIter(numIterations)
+        println("WARNING: Linear Regression with elastic-net in ML package uses LBFGS/OWLQN for" +
+          " optimization which ignores stepSize and uses numIterations for maxIter in Spark 1.5.")
+        val rr = new LinearRegression().setElasticNetParam(elasticNetParam).setRegParam(regParam)
+          .setMaxIter(numIterations)
         val sqlContext = new SQLContext(rdd.context)
         import sqlContext.implicits._
         val mlModel = rr.fit(rdd.toDF())
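
Note: the .setConvergenceTol(0.0) calls added above disable SGD's convergence-based early exit, so every run performs the full numIterations, which keeps benchmark timings comparable across configurations. A minimal stand-alone sketch of the pattern, assuming an existing RDD[LabeledPoint] named trainingData (the name and the hyperparameter values are ours, not part of the patch):

    import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
    import org.apache.spark.rdd.RDD

    def fitForBenchmark(trainingData: RDD[LabeledPoint]) = {
      val lr = new LinearRegressionWithSGD().setIntercept(addIntercept = true)
      lr.optimizer
        .setNumIterations(100)  // run exactly 100 iterations...
        .setStepSize(1.0)
        .setConvergenceTol(0.0) // ...since a 0.0 tolerance never triggers early stopping
      lr.run(trainingData)
    }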
@@ -247,46 +236,68 @@ class GLMClassificationTest(sc: SparkContext) extends GLMTests(sc) {
         s"GLMClassificationTest run with unknown loss ($loss). Supported values: logistic, hinge.")
     }
 
-    if (Array("sgd").contains(optimizer)) {
-      if (!Array("none", "l1", "l2").contains(regType)) {
-        throw new IllegalArgumentException(
-          s"GLMRegressionTest run with unknown regType ($regType) with sgd. Supported values: none, l1, l2.")
-      }
-    } else if (Array("lbfgs").contains(optimizer)) {
-      if (!Array("logistic").contains(loss)) {
-        throw new IllegalArgumentException(
-          s"GLMRegressionTest with lbfgs only supports logistic loss.")
-      }
-      if (!Array("none", "elastic-net").contains(regType)) {
-        throw new IllegalArgumentException(
-          s"GLMRegressionTest run with unknown regType ($regType) with lbfgs. Supported values: none, elastic-net.")
+    if (regType == "elastic-net") { // use spark.ml
+      loss match {
+        case "logistic" =>
+          println("WARNING: Logistic Regression with elastic-net in ML package uses LBFGS/OWLQN for optimization" +
+            " which ignores stepSize in Spark 1.5.")
+          val lor = new LogisticRegression().setElasticNetParam(elasticNetParam).setRegParam(regParam)
+            .setMaxIter(numIterations)
+          val sqlContext = new SQLContext(rdd.context)
+          import sqlContext.implicits._
+          val mlModel = lor.fit(rdd.toDF())
+          new LogisticRegressionModel(mlModel.weights, mlModel.intercept)
+        case _ =>
+          throw new IllegalArgumentException(
+            s"GLMClassificationTest given unsupported loss = $loss. " +
+            s"Note the set of supported combinations increases in later Spark versions.")
       }
     } else {
-      throw new IllegalArgumentException(
-        s"GLMRegressionTest run with unknown optimizer ($optimizer). Supported values: sgd, lbfgs.")
-    }
-
-    (loss, regType, optimizer) match {
-      case ("logistic", "none", "sgd") =>
-        LogisticRegressionWithSGD.train(rdd, numIterations, stepSize)
-      case ("logistic", "none", "lbfgs") =>
-        println("WARNING: LogisticRegressionWithLBFGS ignores numIterations, stepSize" +
-          " in this Spark version.")
-        new LogisticRegressionWithLBFGS().run(rdd)
-      case ("logistic", "elastic-net", _) =>
-        println("WARNING: Logistic Regression with elastic-net in ML package uses LBFGS/OWLQN for optimization" +
-          " which ignores stepSize and uses numIterations for maxIter in Spark 1.5.")
-        val lor = new LogisticRegression().setElasticNetParam(elasticNetParam).setRegParam(regParam).setMaxIter(numIterations)
-        val sqlContext = new SQLContext(rdd.context)
-        import sqlContext.implicits._
-        val mlModel = lor.fit(rdd.toDF())
-        new LogisticRegressionModel(mlModel.weights, mlModel.intercept)
-      case ("hinge", "l2", "sgd") =>
-        SVMWithSGD.train(rdd, numIterations, stepSize, regParam)
-      case _ =>
-        throw new IllegalArgumentException(
-          s"GLMClassificationTest given incompatible (loss, regType) = ($loss, $regType). " +
-          s"Note the set of supported combinations increases in later Spark versions.")
+      (loss, optimizer) match {
+        case ("logistic", "sgd") =>
+          val lr = new LogisticRegressionWithSGD()
+          lr.optimizer.setStepSize(stepSize).setNumIterations(numIterations).setConvergenceTol(0.0)
+          regType match {
+            case "none" =>
+              lr.optimizer.setUpdater(new SimpleUpdater)
+            case "l1" =>
+              lr.optimizer.setUpdater(new L1Updater)
+            case "l2" =>
+              lr.optimizer.setUpdater(new SquaredL2Updater)
+          }
+          lr.run(rdd)
+        case ("logistic", "lbfgs") =>
+          println("WARNING: LogisticRegressionWithLBFGS ignores stepSize in this Spark version.")
+          val lr = new LogisticRegressionWithLBFGS()
+          lr.optimizer.setNumIterations(numIterations).setConvergenceTol(0.0)
+          regType match {
+            case "none" =>
+              lr.optimizer.setUpdater(new SimpleUpdater)
+            case "l1" =>
+              lr.optimizer.setUpdater(new L1Updater)
+            case "l2" =>
+              lr.optimizer.setUpdater(new SquaredL2Updater)
+          }
+          lr.run(rdd)
+        case ("hinge", "sgd") =>
+          val svm = new SVMWithSGD()
+          svm.optimizer.setNumIterations(numIterations).setStepSize(stepSize).setRegParam(regParam)
+            .setConvergenceTol(0.0)
+          regType match {
+            case "none" =>
+              svm.optimizer.setUpdater(new SimpleUpdater)
+            case "l1" =>
+              svm.optimizer.setUpdater(new L1Updater)
+            case "l2" =>
+              svm.optimizer.setUpdater(new SquaredL2Updater)
+          }
+          svm.run(rdd)
+        case _ =>
+          throw new IllegalArgumentException(
+            s"GLMClassificationTest given incompatible (loss, regType) = ($loss, $regType). " +
+            s"Supported combinations include: (elastic-net, _), (logistic, sgd), (logistic, lbfgs), (hinge, sgd). " +
+            s"Note the set of supported combinations increases in later Spark versions.")
+      }
     }
   }
 }
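
Note: each regType match above selects regularization by swapping the Updater handed to the gradient-based optimizer, rather than by choosing a different model class. A minimal sketch of that pattern in isolation, assuming an existing RDD[LabeledPoint] named trainingData (the helper name fitWithRegType and the hyperparameter values are ours, not the patch's):

    import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
    import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater}
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    def fitWithRegType(trainingData: RDD[LabeledPoint], regType: String) = {
      val lr = new LogisticRegressionWithSGD()
      lr.optimizer.setStepSize(1.0).setNumIterations(100).setConvergenceTol(0.0)
      // The Updater plugs the chosen regularization term into each gradient step.
      lr.optimizer.setUpdater(regType match {
        case "none" => new SimpleUpdater
        case "l1"   => new L1Updater
        case "l2"   => new SquaredL2Updater
      })
      lr.run(trainingData)
    }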
@@ -322,7 +333,7 @@ abstract class RecommendationTests(sc: SparkContext) extends PerfTest {
     val implicitRatings: Boolean = booleanOptionValue(IMPLICIT)
 
     val data = DataGenerator.generateRatings(sc, numUsers, numProducts,
-      math.ceil(numRatings * 1.25).toLong, implicitRatings,numPartitions,seed)
+      numRatings, implicitRatings, numPartitions, seed)
 
     rdd = data._1.cache()
     testRdd = data._2
@@ -490,7 +501,7 @@ class ALSTest(sc: SparkContext) extends RecommendationTests(sc) {
     val seed = intOptionValue(RANDOM_SEED) + 12
 
     new ALS().setIterations(numIterations).setRank(rank).setSeed(seed).setLambda(regParam)
-      .setBlocks(rdd.partitions.size).run(rdd)
+      .setBlocks(rdd.partitions.length).run(rdd)
   }
 }
 
@@ -627,7 +638,6 @@ class DecisionTreeTest(sc: SparkContext) extends DecisionTreeTests(sc) {
       seed: Long): (Array[RDD[LabeledPoint]], Map[Int, Int], Int) = {
     // Generic test options
     val numPartitions: Int = intOptionValue(NUM_PARTITIONS)
-    val testDataFraction: Double = getTestDataFraction
     // Data dimensions and type
     val numExamples: Long = longOptionValue(NUM_EXAMPLES)
     val numFeatures: Int = intOptionValue(NUM_FEATURES)
@@ -642,7 +652,7 @@ class DecisionTreeTest(sc: SparkContext) extends DecisionTreeTests(sc) {
       numFeatures, numPartitions, labelType,
       fracCategoricalFeatures, fracBinaryFeatures, treeDepth, seed)
 
-    val splits = rdd_.randomSplit(Array(1.0 - testDataFraction, testDataFraction), seed)
+    val splits = rdd_.randomSplit(Array(0.8, 0.2), seed)
     (splits, categoricalFeaturesInfo_, labelType)
   }
 
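
Note: the hunk above replaces the configurable testDataFraction with a fixed 80/20 train/test split. RDD.randomSplit normalizes its weights internally, so Array(0.8, 0.2) and Array(4, 1) behave the same. A one-line illustration, with labeledData assumed:

    val Array(train, test) = labeledData.randomSplit(Array(0.8, 0.2), seed = 42L)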