Commit 564df59

[breaking] [jvm-packages] Remove scala-implemented tracker. (#9045)

1 parent 42d100d

File tree: 8 files changed, +9, -1585 lines

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala (+6, -14)

@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014-2022 by Contributors
+ Copyright (c) 2014-2023 by Contributors
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -23,7 +23,6 @@ import scala.util.Random
 import scala.collection.JavaConverters._
 
 import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
-import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
 import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -44,21 +43,16 @@ import org.apache.spark.sql.SparkSession
  *        Use a finite, non-zero timeout value to prevent tracker from
  *        hanging indefinitely (in milliseconds)
  *        (supported by "scala" implementation only.)
- * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
- *        the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
- *        in Scala without Python components, and with full support of timeouts.
- *        The Scala implementation is currently experimental, use at your own risk.
- *
  * @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
  *        This is only needed if the host IP cannot be automatically guessed.
  * @param pythonExec The python executed path for Rabit Tracker,
  *        which is only used for python implementation.
  */
-case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String,
+case class TrackerConf(workerConnectionTimeout: Long,
                        hostIp: String = "", pythonExec: String = "")
 
 object TrackerConf {
-  def apply(): TrackerConf = TrackerConf(0L, "python")
+  def apply(): TrackerConf = TrackerConf(0L)
 }
 
 private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
@@ -349,11 +343,9 @@ object XGBoost extends Serializable {
 
   /** visiable for testing */
   private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
-    val tracker: IRabitTracker = trackerConf.trackerImpl match {
-      case "scala" => new RabitTracker(nWorkers)
-      case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec)
-      case _ => new PyRabitTracker(nWorkers)
-    }
+    val tracker: IRabitTracker = new PyRabitTracker(
+      nWorkers, trackerConf.hostIp, trackerConf.pythonExec
+    )
     tracker
   }
 
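For callers, the breaking change is the removed trackerImpl field on TrackerConf: the Java wrapper of the Python Rabit tracker (PyRabitTracker) is now the only implementation. A minimal migration sketch, mirroring how the test suite below constructs TrackerConf positionally (the hostIp value is an illustrative placeholder, not from the commit):

import ml.dmlc.xgboost4j.scala.spark.TrackerConf

val hostIp = "192.168.1.10"            // illustrative placeholder

// Before this commit: the second positional argument selected the
// tracker implementation ("python" or "scala").
//   val conf = TrackerConf(0L, "python", hostIp)

// After this commit: drop the trackerImpl argument; the remaining
// fields are workerConnectionTimeout, hostIp and pythonExec.
val conf = TrackerConf(0L, hostIp)
val default = TrackerConf()            // still equivalent to TrackerConf(0L)

Code that previously requested TrackerConf(0L, "scala") has no direct replacement: the Scala tracker is removed, so such callers fall through to the Python-backed tracker.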

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala (+3, -118)

@@ -22,7 +22,6 @@ import scala.util.Random
 
 import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
-import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
 import ml.dmlc.xgboost4j.scala.DMatrix
 import org.scalatest.FunSuite
 
@@ -40,7 +39,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
 
     val paramMap = Map(
       "num_workers" -> numWorkers,
-      "tracker_conf" -> TrackerConf(0L, "python", hostIp))
+      "tracker_conf" -> TrackerConf(0L, hostIp))
     val xgbExecParams = getXGBoostExecutionParams(paramMap)
     val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
     tracker match {
@@ -53,7 +52,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
 
     val paramMap1 = Map(
       "num_workers" -> numWorkers,
-      "tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
+      "tracker_conf" -> TrackerConf(0L, "", pythonExec))
     val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
     val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
     tracker1 match {
@@ -66,7 +65,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
 
     val paramMap2 = Map(
       "num_workers" -> numWorkers,
-      "tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
+      "tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
     val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
     val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
     tracker2 match {
@@ -78,58 +77,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
     }
   }
 
-  test("training with Scala-implemented Rabit tracker") {
-    val eval = new EvalError()
-    val training = buildDataFrame(Classification.train)
-    val testDM = new DMatrix(Classification.test.iterator)
-    val paramMap = Map("eta" -> "1", "max_depth" -> "6",
-      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
-      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
-    val model = new XGBoostClassifier(paramMap).fit(training)
-    assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
-  }
-
-  test("test Communicator allreduce to validate Scala-implemented Rabit tracker") {
-    val vectorLength = 100
-    val rdd = sc.parallelize(
-      (1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
-
-    val tracker = new ScalaRabitTracker(numWorkers)
-    tracker.start(0)
-    val trackerEnvs = tracker.getWorkerEnvs
-    val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
-
-    val rawData = rdd.mapPartitions { iter =>
-      Iterator(iter.toArray)
-    }.collect()
-
-    val maxVec = (0 until vectorLength).toArray.map { j =>
-      (0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
-    }
-
-    val allReduceResults = rdd.mapPartitions { iter =>
-      Communicator.init(trackerEnvs)
-      val arr = iter.toArray
-      val results = Communicator.allReduce(arr, Communicator.OpType.MAX)
-      Communicator.shutdown()
-      Iterator(results)
-    }.cache()
-
-    val sparkThread = new Thread() {
-      override def run(): Unit = {
-        allReduceResults.foreachPartition(() => _)
-        val byPartitionResults = allReduceResults.collect()
-        assert(byPartitionResults(0).length == vectorLength)
-        collectedAllReduceResults.put(byPartitionResults(0))
-      }
-    }
-    sparkThread.start()
-    assert(tracker.waitFor(0L) == 0)
-    sparkThread.join()
-
-    assert(collectedAllReduceResults.poll().sameElements(maxVec))
-  }
-
   test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
     /*
       Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@@ -193,68 +140,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
     assert(tracker.waitFor(0) != 0)
   }
 
-  test("test Scala RabitTracker's exception handling: it should not hang forever.") {
-    val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
-
-    val tracker = new ScalaRabitTracker(numWorkers)
-    tracker.start(0)
-    val trackerEnvs = tracker.getWorkerEnvs
-
-    val workerCount: Int = numWorkers
-    val dummyTasks = rdd.mapPartitions { iter =>
-      Communicator.init(trackerEnvs)
-      val index = iter.next()
-      Thread.sleep(100 + index * 10)
-      if (index == workerCount) {
-        // kill the worker by throwing an exception
-        throw new RuntimeException("Worker exception.")
-      }
-      Communicator.shutdown()
-      Iterator(index)
-    }.cache()
-
-    val sparkThread = new Thread() {
-      override def run(): Unit = {
-        // forces a Spark job.
-        dummyTasks.foreachPartition(() => _)
-      }
-    }
-    sparkThread.setUncaughtExceptionHandler(tracker)
-    sparkThread.start()
-    assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
-  }
-
-  test("test Scala RabitTracker's workerConnectionTimeout") {
-    val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
-
-    val tracker = new ScalaRabitTracker(numWorkers)
-    tracker.start(500)
-    val trackerEnvs = tracker.getWorkerEnvs
-
-    val dummyTasks = rdd.mapPartitions { iter =>
-      val index = iter.next()
-      // simulate that the first worker cannot connect to tracker due to network issues.
-      if (index != 1) {
-        Communicator.init(trackerEnvs)
-        Thread.sleep(1000)
-        Communicator.shutdown()
-      }
-
-      Iterator(index)
-    }.cache()
-
-    val sparkThread = new Thread() {
-      override def run(): Unit = {
-        // forces a Spark job.
-        dummyTasks.foreachPartition(() => _)
-      }
-    }
-    sparkThread.setUncaughtExceptionHandler(tracker)
-    sparkThread.start()
-    // should fail due to connection timeout
-    assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
-  }
-
   test("should allow the dataframe containing communicator calls to be partially evaluated for" +
     " multiple times (ISSUE-4406)") {
     val paramMap = Map(
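The deleted allreduce round-trip can still be exercised through the Python-backed tracker that this commit keeps. Below is an untested sketch mirroring the removed test, assuming PyRabitTracker exposes the same IRabitTracker surface (start, getWorkerEnvs, waitFor) used by XGBoost.getTracker, and that sc and numWorkers come from the PerTest harness as in the rest of the suite:

// Sketch only: the removed Scala-tracker allreduce test, re-targeted at the
// Java wrapper of the Python tracker (the sole implementation after this commit).
val vectorLength = 100
val tracker = new PyRabitTracker(numWorkers) // single-arg ctor, as in the old getTracker
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs

val rdd = sc.parallelize(
  (1 to numWorkers * vectorLength).map(_ => Random.nextFloat()), numWorkers).cache()

val allReduceResults = rdd.mapPartitions { iter =>
  Communicator.init(trackerEnvs)
  val results = Communicator.allReduce(iter.toArray, Communicator.OpType.MAX)
  Communicator.shutdown()
  Iterator(results)
}
allReduceResults.count()          // force the Spark job so workers connect and finish
assert(tracker.waitFor(0L) == 0)  // tracker exits cleanly once all workers shut down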

0 commit comments