@@ -22,7 +22,6 @@ import scala.util.Random
 
 import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
-import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
 import ml.dmlc.xgboost4j.scala.DMatrix
 import org.scalatest.FunSuite
 
@@ -40,7 +39,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
 
     val paramMap = Map(
       "num_workers" -> numWorkers,
-      "tracker_conf" -> TrackerConf(0L, "python", hostIp))
+      "tracker_conf" -> TrackerConf(0L, hostIp))
     val xgbExecParams = getXGBoostExecutionParams(paramMap)
     val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
     tracker match {
@@ -53,7 +52,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
 
     val paramMap1 = Map(
       "num_workers" -> numWorkers,
-      "tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
+      "tracker_conf" -> TrackerConf(0L, "", pythonExec))
     val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
     val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
     tracker1 match {
@@ -66,7 +65,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
 
     val paramMap2 = Map(
       "num_workers" -> numWorkers,
-      "tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
+      "tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
     val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
     val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
     tracker2 match {
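
The three hunks above are one and the same signature change: with the Scala tracker implementation removed, TrackerConf no longer takes the "python"/"scala" selector string as its second argument. A minimal sketch of the assumed post-change constructor, with parameter names inferred from the call sites in this diff rather than taken from the library source:

// Sketch only: assumed shape after this commit is
//   TrackerConf(workerConnectionTimeout: Long, hostIp: String = "", pythonExec: String = "")
import ml.dmlc.xgboost4j.scala.spark.TrackerConf

val bindToHost = TrackerConf(0L, "10.0.0.5")               // hypothetical tracker host IP, no timeout
val customPython = TrackerConf(0L, "", "/usr/bin/python3") // default host, explicit python executable
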
@@ -78,58 +77,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
     }
   }
 
-  test("training with Scala-implemented Rabit tracker") {
-    val eval = new EvalError()
-    val training = buildDataFrame(Classification.train)
-    val testDM = new DMatrix(Classification.test.iterator)
-    val paramMap = Map("eta" -> "1", "max_depth" -> "6",
-      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
-      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
-    val model = new XGBoostClassifier(paramMap).fit(training)
-    assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
-  }
-
-  test("test Communicator allreduce to validate Scala-implemented Rabit tracker") {
-    val vectorLength = 100
-    val rdd = sc.parallelize(
-      (1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
-
-    val tracker = new ScalaRabitTracker(numWorkers)
-    tracker.start(0)
-    val trackerEnvs = tracker.getWorkerEnvs
-    val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
-
-    val rawData = rdd.mapPartitions { iter =>
-      Iterator(iter.toArray)
-    }.collect()
-
-    val maxVec = (0 until vectorLength).toArray.map { j =>
-      (0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
-    }
-
-    val allReduceResults = rdd.mapPartitions { iter =>
-      Communicator.init(trackerEnvs)
-      val arr = iter.toArray
-      val results = Communicator.allReduce(arr, Communicator.OpType.MAX)
-      Communicator.shutdown()
-      Iterator(results)
-    }.cache()
-
-    val sparkThread = new Thread() {
-      override def run(): Unit = {
-        allReduceResults.foreachPartition(() => _)
-        val byPartitionResults = allReduceResults.collect()
-        assert(byPartitionResults(0).length == vectorLength)
-        collectedAllReduceResults.put(byPartitionResults(0))
-      }
-    }
-    sparkThread.start()
-    assert(tracker.waitFor(0L) == 0)
-    sparkThread.join()
-
-    assert(collectedAllReduceResults.poll().sameElements(maxVec))
-  }
-
   test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
     /*
       Deliberately create new instances of SparkContext in each unit test to avoid reusing the
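
The hunk above deletes the only allreduce validation in this suite; the equivalent check can still be driven through the Java tracker wrapper (imported here as PyRabitTracker) that the surviving tests use. A sketch under that assumption, reusing the suite's sc and numWorkers and the same Communicator calls as the removed test:

// Sketch only: MAX-allreduce smoke test against the Java tracker wrapper.
val tracker = new PyRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val rdd = sc.parallelize(1 to numWorkers, numWorkers)
val reduced = rdd.mapPartitions { iter =>
  Communicator.init(trackerEnvs)   // each partition joins the tracker's ring
  val out = Communicator.allReduce(iter.map(_.toFloat).toArray, Communicator.OpType.MAX)
  Communicator.shutdown()
  Iterator(out)
}.collect()
assert(tracker.waitFor(0L) == 0)   // tracker should exit cleanly once workers shut down
// every partition held one value in 1..numWorkers, so MAX is numWorkers everywhere
assert(reduced.forall(_.sameElements(Array(numWorkers.toFloat))))
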
@@ -193,68 +140,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
     assert(tracker.waitFor(0) != 0)
   }
 
-  test("test Scala RabitTracker's exception handling: it should not hang forever.") {
-    val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
-
-    val tracker = new ScalaRabitTracker(numWorkers)
-    tracker.start(0)
-    val trackerEnvs = tracker.getWorkerEnvs
-
-    val workerCount: Int = numWorkers
-    val dummyTasks = rdd.mapPartitions { iter =>
-      Communicator.init(trackerEnvs)
-      val index = iter.next()
-      Thread.sleep(100 + index * 10)
-      if (index == workerCount) {
-        // kill the worker by throwing an exception
-        throw new RuntimeException("Worker exception.")
-      }
-      Communicator.shutdown()
-      Iterator(index)
-    }.cache()
-
-    val sparkThread = new Thread() {
-      override def run(): Unit = {
-        // forces a Spark job.
-        dummyTasks.foreachPartition(() => _)
-      }
-    }
-    sparkThread.setUncaughtExceptionHandler(tracker)
-    sparkThread.start()
-    assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
-  }
-
-  test("test Scala RabitTracker's workerConnectionTimeout") {
-    val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
-
-    val tracker = new ScalaRabitTracker(numWorkers)
-    tracker.start(500)
-    val trackerEnvs = tracker.getWorkerEnvs
-
-    val dummyTasks = rdd.mapPartitions { iter =>
-      val index = iter.next()
-      // simulate that the first worker cannot connect to tracker due to network issues.
-      if (index != 1) {
-        Communicator.init(trackerEnvs)
-        Thread.sleep(1000)
-        Communicator.shutdown()
-      }
-
-      Iterator(index)
-    }.cache()
-
-    val sparkThread = new Thread() {
-      override def run(): Unit = {
-        // forces a Spark job.
-        dummyTasks.foreachPartition(() => _)
-      }
-    }
-    sparkThread.setUncaughtExceptionHandler(tracker)
-    sparkThread.start()
-    // should fail due to connection timeout
-    assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
-  }
-
   test("should allow the dataframe containing communicator calls to be partially evaluated for" +
     "multiple times (ISSUE-4406)") {
     val paramMap = Map(