Skip to content

Commit 0280f37

Browse files
committed
add getPreferredLocations
1 parent 5dd370f commit 0280f37

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

src/com/redislabs/provider/RedisConfig.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package com.redislabs.provider
33

44
import com.redislabs.provider.redis.NodesInfo._
55

6-
class RedisConfig(ip: String, port: Int) extends Serializable {
6+
class RedisConfig(val ip: String, val port: Int) extends Serializable {
77
val nodes: java.util.ArrayList[(String, Int)] = new java.util.ArrayList[(String, Int)]
88

99
getNodes((ip, port)).foreach(x => nodes.add((x._1, x._2)))

src/com/redislabs/provider/redis/rdd/RedisRDD.scala

+47-3
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,56 @@ class RedisKeysRDD(sc: SparkContext,
120120
val partitionNum: Int = 3)
121121
extends RDD[String](sc, Seq.empty) with Logging with Keys {
122122

123+
override protected def getPreferredLocations(split: Partition): Seq[String] = {
124+
Seq(split.asInstanceOf[RedisPartition].redisConfig.ip)
125+
}
126+
127+
private def scaleHostsWithPartitionNum(): Seq[(String, Int, Int, Int)] = {
128+
def split(host: (String, Int, Int, Int), cnt: Int) = {
129+
val start = host._3
130+
val end = host._4
131+
val range = (end - start) / cnt
132+
(0 until cnt).map(i => {
133+
(host._1,
134+
host._2,
135+
if (i == 0) start else (start + range * i + 1),
136+
if (i != cnt - 1) (start + range * (i + 1)) else end)
137+
})
138+
}
139+
140+
val hosts = com.redislabs.provider.redis.NodesInfo.getHosts(redisNode)
141+
if (hosts.size == partitionNum)
142+
hosts
143+
else if (hosts.size < partitionNum) {
144+
val presExtCnt = partitionNum / hosts.size
145+
val lastExtCnt = if (presExtCnt * hosts.size < partitionNum) (presExtCnt + 1) else presExtCnt
146+
hosts.zipWithIndex.flatMap{
147+
case(host, idx) => {
148+
split(host, if (idx == hosts.size - 1) lastExtCnt else presExtCnt)
149+
}
150+
}
151+
}
152+
else {
153+
val presExtCnt = hosts.size / partitionNum
154+
val lastExtCnt = if (presExtCnt * partitionNum < hosts.size) (presExtCnt + 1) else presExtCnt
155+
(0 until partitionNum).map{
156+
idx => {
157+
val ip = hosts(idx * presExtCnt)._1
158+
val port = hosts(idx * presExtCnt)._2
159+
val start = hosts(idx * presExtCnt)._3
160+
val end = hosts(if (idx == partitionNum - 1) (hosts.size-1) else ((idx + 1) * presExtCnt - 1))._4
161+
(ip, port, start, end)
162+
}
163+
}
164+
}
165+
}
166+
123167
override protected def getPartitions: Array[Partition] = {
124-
val cnt = 16384 / partitionNum
168+
val hosts = scaleHostsWithPartitionNum()
125169
(0 until partitionNum).map(i => {
126170
new RedisPartition(i,
127-
new RedisConfig(redisNode._1, redisNode._2),
128-
(if (i == 0) 0 else cnt * i + 1, if (i != partitionNum - 1) cnt * (i + 1) else 16383)).asInstanceOf[Partition]
171+
new RedisConfig(hosts(i)._1, hosts(i)._2),
172+
(hosts(i)._3, hosts(i)._4)).asInstanceOf[Partition]
129173
}).toArray
130174
}
131175

0 commit comments

Comments
 (0)