Skip to content

Commit 766bfcc

Browse files
authored
[JVM-Packages] Allow XGBoost jvm package run on GPU without rapids (dmlc#11184)
1 parent 2e1626c commit 766bfcc

File tree

3 files changed: +28 additions, −28 deletions

jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright (c) 2024 by Contributors
2+
Copyright (c) 2024-2025 by Contributors
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.

jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -90,24 +90,23 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
9090
val df = Seq((1.0f, 2.0f, 0.0f),
9191
(2.0f, 3.0f, 1.0f)
9292
).toDF("c1", "c2", "label")
93-
val classifier = new XGBoostClassifier()
94-
assert(classifier.getPlugin.isDefined)
95-
assert(classifier.getPlugin.get.isEnabled(df) === expected)
93+
assert(PluginUtils.getPlugin.isDefined)
94+
assert(PluginUtils.getPlugin.get.isEnabled(df) === expected)
9695
}
9796

9897
// spark.rapids.sql.enabled is not set explicitly, default to true
9998
withSparkSession(new SparkConf(), spark => {
100-
checkIsEnabled(spark, true)
99+
checkIsEnabled(spark, expected = true)
101100
})
102101

103102
// set spark.rapids.sql.enabled to false
104103
withCpuSparkSession() { spark =>
105-
checkIsEnabled(spark, false)
104+
checkIsEnabled(spark, expected = false)
106105
}
107106

108107
// set spark.rapids.sql.enabled to true
109108
withGpuSparkSession() { spark =>
110-
checkIsEnabled(spark, true)
109+
checkIsEnabled(spark, expected = true)
111110
}
112111
}
113112

@@ -122,7 +121,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
122121
).toDF("c1", "c2", "weight", "margin", "label", "other")
123122
val classifier = new XGBoostClassifier()
124123

125-
val plugin = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
124+
val plugin = PluginUtils.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
126125
intercept[IllegalArgumentException] {
127126
plugin.validate(classifier, df)
128127
}
@@ -156,9 +155,9 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
156155
var classifier = new XGBoostClassifier()
157156
.setNumWorkers(3)
158157
.setFeaturesCol(features)
159-
assert(classifier.getPlugin.isDefined)
160-
assert(classifier.getPlugin.get.isInstanceOf[GpuXGBoostPlugin])
161-
var out = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
158+
assert(PluginUtils.getPlugin.isDefined)
159+
assert(PluginUtils.getPlugin.get.isInstanceOf[GpuXGBoostPlugin])
160+
var out = PluginUtils.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
162161
.preprocess(classifier, df)
163162

164163
assert(out.schema.names.contains("c1") && out.schema.names.contains("c2"))
@@ -172,7 +171,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
172171
.setWeightCol("weight")
173172
.setBaseMarginCol("margin")
174173
.setDevice("cuda")
175-
out = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
174+
out = PluginUtils.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
176175
.preprocess(classifier, df)
177176

178177
assert(out.schema.names.contains("c1") && out.schema.names.contains("c2"))
@@ -207,7 +206,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
207206
.setDevice("cuda")
208207
.setMissing(missing)
209208

210-
val rdd = classifier.getPlugin.get.buildRddWatches(classifier, df)
209+
val rdd = PluginUtils.getPlugin.get.buildRddWatches(classifier, df)
211210
val result = rdd.mapPartitions { iter =>
212211
val watches = iter.next()
213212
val size = watches.size
@@ -271,7 +270,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
271270
.setMissing(missing)
272271
.setEvalDataset(eval)
273272

274-
val rdd = classifier.getPlugin.get.buildRddWatches(classifier, train)
273+
val rdd = PluginUtils.getPlugin.get.buildRddWatches(classifier, train)
275274
val result = rdd.mapPartitions { iter =>
276275
val watches = iter.next()
277276
val size = watches.size
@@ -324,7 +323,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
324323
.setLabelCol("label")
325324
.setDevice("cuda")
326325

327-
assert(estimator.getPlugin.isDefined && estimator.getPlugin.get.isEnabled(df))
326+
assert(PluginUtils.getPlugin.isDefined && PluginUtils.getPlugin.get.isEnabled(df))
328327

329328
val out = estimator.fit(df).transform(df)
330329
// Transform should not discard the other columns of the transforming dataframe
@@ -528,7 +527,8 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
528527
.setGroupCol(group)
529528
.setDevice("cuda")
530529

531-
val processedDf = ranker.getPlugin.get.asInstanceOf[GpuXGBoostPlugin].preprocess(ranker, df)
530+
val processedDf = PluginUtils.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
531+
.preprocess(ranker, df)
532532
processedDf.rdd.foreachPartition { iter => {
533533
var prevGroup = Int.MinValue
534534
while (iter.hasNext) {
@@ -575,7 +575,8 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
575575
// The fix has replaced repartition with repartitionByRange which will put the
576576
// instances with same group into the same partition
577577
val ranker = new XGBoostRanker().setGroupCol("group").setNumWorkers(num_workers)
578-
val processedDf = ranker.getPlugin.get.asInstanceOf[GpuXGBoostPlugin].preprocess(ranker, df)
578+
val processedDf = PluginUtils.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
579+
.preprocess(ranker, df)
579580
val rows = processedDf
580581
.select("group")
581582
.mapPartitions { case iter =>

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright (c) 2024 by Contributors
2+
Copyright (c) 2024-2025 by Contributors
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -66,7 +66,7 @@ private[spark] trait NonParamVariables[T <: XGBoostEstimator[T, M], M <: XGBoost
6666
}
6767
}
6868

69-
private[spark] trait PluginMixin {
69+
private[spark] object PluginUtils {
7070
// Find the XGBoostPlugin by ServiceLoader
7171
private val plugin: Option[XGBoostPlugin] = {
7272
val classLoader = Option(Thread.currentThread().getContextClassLoader)
@@ -85,18 +85,17 @@ private[spark] trait PluginMixin {
8585
}
8686

8787
/** Visible for testing */
88-
protected[spark] def getPlugin: Option[XGBoostPlugin] = plugin
88+
def getPlugin: Option[XGBoostPlugin] = plugin
8989

90-
protected def isPluginEnabled(dataset: Dataset[_]): Boolean = {
90+
def isPluginEnabled(dataset: Dataset[_]): Boolean = {
9191
plugin.map(_.isEnabled(dataset)).getOrElse(false)
9292
}
9393
}
9494

9595
private[spark] trait XGBoostEstimator[
9696
Learner <: XGBoostEstimator[Learner, M], M <: XGBoostModel[M]] extends Estimator[M]
9797
with XGBoostParams[Learner] with SparkParams[Learner] with ParamUtils[Learner]
98-
with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable
99-
with PluginMixin {
98+
with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable {
10099

101100
protected val logger = LogFactory.getLog("XGBoostSpark")
102101

@@ -428,8 +427,8 @@ private[spark] trait XGBoostEstimator[
428427
protected def train(dataset: Dataset[_]): M = {
429428
validate(dataset)
430429

431-
val rdd = if (isPluginEnabled(dataset)) {
432-
getPlugin.get.buildRddWatches(this, dataset)
430+
val rdd = if (PluginUtils.isPluginEnabled(dataset)) {
431+
PluginUtils.getPlugin.get.buildRddWatches(this, dataset)
433432
} else {
434433
val (input, columnIndexes) = preprocess(dataset)
435434
toRdd(input, columnIndexes)
@@ -466,7 +465,7 @@ private[spark] case class PredictedColumns(
466465
* XGBoost base model
467466
*/
468467
private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with MLWritable
469-
with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] with PluginMixin {
468+
with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] {
470469

471470
protected val TMP_TRANSFORMED_COL = "_tmp_xgb_transformed_col"
472471

@@ -597,8 +596,8 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
597596
}
598597

599598
override def transform(dataset: Dataset[_]): DataFrame = {
600-
if (getPlugin.isDefined) {
601-
return getPlugin.get.transform(this, dataset)
599+
if (PluginUtils.isPluginEnabled(dataset)) {
600+
return PluginUtils.getPlugin.get.transform(this, dataset)
602601
}
603602
validateFeatureType(dataset.schema)
604603
val (schema, pred) = preprocess(dataset)

0 commit comments

Comments (0)