@@ -90,24 +90,23 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
90
90
val df = Seq ((1.0f , 2.0f , 0.0f ),
91
91
(2.0f , 3.0f , 1.0f )
92
92
).toDF(" c1" , " c2" , " label" )
93
- val classifier = new XGBoostClassifier ()
94
- assert(classifier.getPlugin.isDefined)
95
- assert(classifier.getPlugin.get.isEnabled(df) === expected)
93
+ assert(PluginUtils .getPlugin.isDefined)
94
+ assert(PluginUtils .getPlugin.get.isEnabled(df) === expected)
96
95
}
97
96
98
97
// spark.rapids.sql.enabled is not set explicitly, default to true
99
98
withSparkSession(new SparkConf (), spark => {
100
- checkIsEnabled(spark, true )
99
+ checkIsEnabled(spark, expected = true )
101
100
})
102
101
103
102
// set spark.rapids.sql.enabled to false
104
103
withCpuSparkSession() { spark =>
105
- checkIsEnabled(spark, false )
104
+ checkIsEnabled(spark, expected = false )
106
105
}
107
106
108
107
// set spark.rapids.sql.enabled to true
109
108
withGpuSparkSession() { spark =>
110
- checkIsEnabled(spark, true )
109
+ checkIsEnabled(spark, expected = true )
111
110
}
112
111
}
113
112
@@ -122,7 +121,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
122
121
).toDF(" c1" , " c2" , " weight" , " margin" , " label" , " other" )
123
122
val classifier = new XGBoostClassifier ()
124
123
125
- val plugin = classifier .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
124
+ val plugin = PluginUtils .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
126
125
intercept[IllegalArgumentException ] {
127
126
plugin.validate(classifier, df)
128
127
}
@@ -156,9 +155,9 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
156
155
var classifier = new XGBoostClassifier ()
157
156
.setNumWorkers(3 )
158
157
.setFeaturesCol(features)
159
- assert(classifier .getPlugin.isDefined)
160
- assert(classifier .getPlugin.get.isInstanceOf [GpuXGBoostPlugin ])
161
- var out = classifier .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
158
+ assert(PluginUtils .getPlugin.isDefined)
159
+ assert(PluginUtils .getPlugin.get.isInstanceOf [GpuXGBoostPlugin ])
160
+ var out = PluginUtils .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
162
161
.preprocess(classifier, df)
163
162
164
163
assert(out.schema.names.contains(" c1" ) && out.schema.names.contains(" c2" ))
@@ -172,7 +171,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
172
171
.setWeightCol(" weight" )
173
172
.setBaseMarginCol(" margin" )
174
173
.setDevice(" cuda" )
175
- out = classifier .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
174
+ out = PluginUtils .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
176
175
.preprocess(classifier, df)
177
176
178
177
assert(out.schema.names.contains(" c1" ) && out.schema.names.contains(" c2" ))
@@ -207,7 +206,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
207
206
.setDevice(" cuda" )
208
207
.setMissing(missing)
209
208
210
- val rdd = classifier .getPlugin.get.buildRddWatches(classifier, df)
209
+ val rdd = PluginUtils .getPlugin.get.buildRddWatches(classifier, df)
211
210
val result = rdd.mapPartitions { iter =>
212
211
val watches = iter.next()
213
212
val size = watches.size
@@ -271,7 +270,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
271
270
.setMissing(missing)
272
271
.setEvalDataset(eval)
273
272
274
- val rdd = classifier .getPlugin.get.buildRddWatches(classifier, train)
273
+ val rdd = PluginUtils .getPlugin.get.buildRddWatches(classifier, train)
275
274
val result = rdd.mapPartitions { iter =>
276
275
val watches = iter.next()
277
276
val size = watches.size
@@ -324,7 +323,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
324
323
.setLabelCol(" label" )
325
324
.setDevice(" cuda" )
326
325
327
- assert(estimator .getPlugin.isDefined && estimator .getPlugin.get.isEnabled(df))
326
+ assert(PluginUtils .getPlugin.isDefined && PluginUtils .getPlugin.get.isEnabled(df))
328
327
329
328
val out = estimator.fit(df).transform(df)
330
329
// Transform should not discard the other columns of the transforming dataframe
@@ -528,7 +527,8 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
528
527
.setGroupCol(group)
529
528
.setDevice(" cuda" )
530
529
531
- val processedDf = ranker.getPlugin.get.asInstanceOf [GpuXGBoostPlugin ].preprocess(ranker, df)
530
+ val processedDf = PluginUtils .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
531
+ .preprocess(ranker, df)
532
532
processedDf.rdd.foreachPartition { iter => {
533
533
var prevGroup = Int .MinValue
534
534
while (iter.hasNext) {
@@ -575,7 +575,8 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
575
575
// The fix has replaced repartition with repartitionByRange which will put the
576
576
// instances with same group into the same partition
577
577
val ranker = new XGBoostRanker ().setGroupCol(" group" ).setNumWorkers(num_workers)
578
- val processedDf = ranker.getPlugin.get.asInstanceOf [GpuXGBoostPlugin ].preprocess(ranker, df)
578
+ val processedDf = PluginUtils .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
579
+ .preprocess(ranker, df)
579
580
val rows = processedDf
580
581
.select(" group" )
581
582
.mapPartitions { case iter =>
0 commit comments