@@ -90,24 +90,23 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
9090 val df = Seq ((1.0f , 2.0f , 0.0f ),
9191 (2.0f , 3.0f , 1.0f )
9292 ).toDF(" c1" , " c2" , " label" )
93- val classifier = new XGBoostClassifier ()
94- assert(classifier.getPlugin.isDefined)
95- assert(classifier.getPlugin.get.isEnabled(df) === expected)
93+ assert(PluginUtils .getPlugin.isDefined)
94+ assert(PluginUtils .getPlugin.get.isEnabled(df) === expected)
9695 }
9796
9897 // spark.rapids.sql.enabled is not set explicitly, default to true
9998 withSparkSession(new SparkConf (), spark => {
100- checkIsEnabled(spark, true )
99+ checkIsEnabled(spark, expected = true )
101100 })
102101
103102 // set spark.rapids.sql.enabled to false
104103 withCpuSparkSession() { spark =>
105- checkIsEnabled(spark, false )
104+ checkIsEnabled(spark, expected = false )
106105 }
107106
108107 // set spark.rapids.sql.enabled to true
109108 withGpuSparkSession() { spark =>
110- checkIsEnabled(spark, true )
109+ checkIsEnabled(spark, expected = true )
111110 }
112111 }
113112
@@ -122,7 +121,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
122121 ).toDF(" c1" , " c2" , " weight" , " margin" , " label" , " other" )
123122 val classifier = new XGBoostClassifier ()
124123
125- val plugin = classifier .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
124+ val plugin = PluginUtils .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
126125 intercept[IllegalArgumentException ] {
127126 plugin.validate(classifier, df)
128127 }
@@ -156,9 +155,9 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
156155 var classifier = new XGBoostClassifier ()
157156 .setNumWorkers(3 )
158157 .setFeaturesCol(features)
159- assert(classifier .getPlugin.isDefined)
160- assert(classifier .getPlugin.get.isInstanceOf [GpuXGBoostPlugin ])
161- var out = classifier .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
158+ assert(PluginUtils .getPlugin.isDefined)
159+ assert(PluginUtils .getPlugin.get.isInstanceOf [GpuXGBoostPlugin ])
160+ var out = PluginUtils .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
162161 .preprocess(classifier, df)
163162
164163 assert(out.schema.names.contains(" c1" ) && out.schema.names.contains(" c2" ))
@@ -172,7 +171,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
172171 .setWeightCol(" weight" )
173172 .setBaseMarginCol(" margin" )
174173 .setDevice(" cuda" )
175- out = classifier .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
174+ out = PluginUtils .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
176175 .preprocess(classifier, df)
177176
178177 assert(out.schema.names.contains(" c1" ) && out.schema.names.contains(" c2" ))
@@ -207,7 +206,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
207206 .setDevice(" cuda" )
208207 .setMissing(missing)
209208
210- val rdd = classifier .getPlugin.get.buildRddWatches(classifier, df)
209+ val rdd = PluginUtils .getPlugin.get.buildRddWatches(classifier, df)
211210 val result = rdd.mapPartitions { iter =>
212211 val watches = iter.next()
213212 val size = watches.size
@@ -271,7 +270,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
271270 .setMissing(missing)
272271 .setEvalDataset(eval)
273272
274- val rdd = classifier .getPlugin.get.buildRddWatches(classifier, train)
273+ val rdd = PluginUtils .getPlugin.get.buildRddWatches(classifier, train)
275274 val result = rdd.mapPartitions { iter =>
276275 val watches = iter.next()
277276 val size = watches.size
@@ -324,7 +323,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
324323 .setLabelCol(" label" )
325324 .setDevice(" cuda" )
326325
327- assert(estimator .getPlugin.isDefined && estimator .getPlugin.get.isEnabled(df))
326+ assert(PluginUtils .getPlugin.isDefined && PluginUtils .getPlugin.get.isEnabled(df))
328327
329328 val out = estimator.fit(df).transform(df)
330329 // Transform should not discard the other columns of the transforming dataframe
@@ -528,7 +527,8 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
528527 .setGroupCol(group)
529528 .setDevice(" cuda" )
530529
531- val processedDf = ranker.getPlugin.get.asInstanceOf [GpuXGBoostPlugin ].preprocess(ranker, df)
530+ val processedDf = PluginUtils .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
531+ .preprocess(ranker, df)
532532 processedDf.rdd.foreachPartition { iter => {
533533 var prevGroup = Int .MinValue
534534 while (iter.hasNext) {
@@ -575,7 +575,8 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
575575 // The fix has replaced repartition with repartitionByRange which will put the
576576 // instances with same group into the same partition
577577 val ranker = new XGBoostRanker ().setGroupCol(" group" ).setNumWorkers(num_workers)
578- val processedDf = ranker.getPlugin.get.asInstanceOf [GpuXGBoostPlugin ].preprocess(ranker, df)
578+ val processedDf = PluginUtils .getPlugin.get.asInstanceOf [GpuXGBoostPlugin ]
579+ .preprocess(ranker, df)
579580 val rows = processedDf
580581 .select(" group" )
581582 .mapPartitions { case iter =>
0 commit comments