-
Notifications
You must be signed in to change notification settings - Fork 370
/
Copy pathRedisConfig.scala
382 lines (332 loc) · 12.1 KB
/
RedisConfig.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
package com.redislabs.provider.redis
import java.net.URI
import org.apache.spark.SparkConf
import redis.clients.jedis.resps.{ClusterShardInfo, ClusterShardNodeInfo}
import redis.clients.jedis.util.{JedisClusterCRC16, JedisURIHelper, SafeEncoder}
import redis.clients.jedis.{Jedis, Protocol}
import java.nio.charset.Charset
import scala.jdk.CollectionConverters._
/**
 * RedisEndpoint represents a redis connection endpoint info: host, port, auth password
 * db number, timeout and ssl mode
 *
 * @param host    the redis host or ip
 * @param port    the redis port
 * @param user    the authentication username
 * @param auth    the authentication password
 * @param dbNum   database number (should be avoided in general)
 * @param timeout connection timeout
 * @param ssl     true to enable SSL connection. Defaults to false
 */
case class RedisEndpoint(host: String = Protocol.DEFAULT_HOST,
                         port: Int = Protocol.DEFAULT_PORT,
                         user: String = null,
                         auth: String = null,
                         dbNum: Int = Protocol.DEFAULT_DATABASE,
                         timeout: Int = Protocol.DEFAULT_TIMEOUT,
                         ssl: Boolean = false)
  extends Serializable {

  /**
   * Constructor from spark config. Reads spark.redis.host, spark.redis.port,
   * spark.redis.user, spark.redis.auth, spark.redis.db, spark.redis.timeout
   * and spark.redis.ssl.
   *
   * @param conf spark context config
   */
  def this(conf: SparkConf) =
    this(
      host = conf.get("spark.redis.host", Protocol.DEFAULT_HOST),
      port = conf.getInt("spark.redis.port", Protocol.DEFAULT_PORT),
      user = conf.get("spark.redis.user", null),
      auth = conf.get("spark.redis.auth", null),
      dbNum = conf.getInt("spark.redis.db", Protocol.DEFAULT_DATABASE),
      timeout = conf.getInt("spark.redis.timeout", Protocol.DEFAULT_TIMEOUT),
      ssl = conf.getBoolean("spark.redis.ssl", false)
    )

  /**
   * Constructor from spark config and parameters; an entry in `parameters`
   * takes precedence over the corresponding spark config value.
   *
   * @param conf       spark context config
   * @param parameters source specific parameters
   */
  def this(conf: SparkConf, parameters: Map[String, String]) =
    this(
      host = parameters.getOrElse("host", conf.get("spark.redis.host", Protocol.DEFAULT_HOST)),
      port = parameters.getOrElse("port", conf.get("spark.redis.port", Protocol.DEFAULT_PORT.toString)).toInt,
      user = parameters.getOrElse("user", conf.get("spark.redis.user", null)),
      auth = parameters.getOrElse("auth", conf.get("spark.redis.auth", null)),
      dbNum = parameters.getOrElse("dbNum", conf.get("spark.redis.db", Protocol.DEFAULT_DATABASE.toString)).toInt,
      timeout = parameters.getOrElse("timeout", conf.get("spark.redis.timeout", Protocol.DEFAULT_TIMEOUT.toString)).toInt,
      ssl = parameters.getOrElse("ssl", conf.get("spark.redis.ssl", false.toString)).toBoolean
    )

  /**
   * Constructor with Jedis URI
   *
   * @param uri connection URI in the form of redis://$user:$password@$host:$port/[dbnum].
   *            Use "rediss://" scheme for redis SSL
   */
  def this(uri: URI) =
    this(
      host = uri.getHost,
      port = uri.getPort,
      user = JedisURIHelper.getUser(uri),
      auth = JedisURIHelper.getPassword(uri),
      dbNum = JedisURIHelper.getDBIndex(uri),
      timeout = Protocol.DEFAULT_TIMEOUT,
      ssl = uri.getScheme == RedisSslScheme
    )

  /**
   * Constructor with Jedis URI from String
   *
   * @param uri connection URI in the form of redis://$user:$password@$host:$port/[dbnum].
   *            Use "rediss://" scheme for redis SSL
   */
  def this(uri: String) = this(URI.create(uri))

  /**
   * Connect tries to open a connection to the redis endpoint,
   * optionally authenticating and selecting a db
   *
   * @return a new Jedis instance
   */
  def connect(): Jedis = ConnectionPool.connect(this)

  /**
   * @return a copy of this config with the password blanked out. Used for logging.
   */
  def maskPassword(): RedisEndpoint = copy(auth = "")
}
/**
 * A single Redis node together with the hash-slot range it serves.
 *
 * @param endpoint  connection details for the node
 * @param startSlot first hash slot served by the node
 * @param endSlot   last hash slot served by the node
 * @param idx       index of the node within its replication group; 0 denotes the master
 * @param total     number of nodes in the replication group
 */
case class RedisNode(endpoint: RedisEndpoint,
                     startSlot: Int,
                     endSlot: Int,
                     idx: Int,
                     total: Int) {

  /** Opens a connection to this node's endpoint. */
  def connect(): Jedis = endpoint.connect()
}
/**
 * Tuning options for read and write operations.
 *
 * @param scanCount                    count option passed to the SCAN command when reading keys
 * @param maxPipelineSize              maximum number of commands per Redis pipeline
 * @param rddWriteIteratorGroupingSize iterator grouping size used when writing an RDD
 */
case class ReadWriteConfig(scanCount: Int, maxPipelineSize: Int, rddWriteIteratorGroupingSize: Int)
object ReadWriteConfig {
  /** count option of SCAN command **/
  val ScanCountConfKey = "spark.redis.scan.count"
  val ScanCountDefault = 100

  /** maximum number of commands per pipeline **/
  val MaxPipelineSizeConfKey = "spark.redis.max.pipeline.size"
  val MaxPipelineSizeDefault = 100

  /** Iterator grouping size for writing RDD **/
  val RddWriteIteratorGroupingSizeKey = "spark.redis.rdd.write.iterator.grouping.size"
  val RddWriteIteratorGroupingSizeDefault = 1000

  /** Configuration used when no Spark settings are supplied. */
  val Default: ReadWriteConfig = ReadWriteConfig(
    scanCount = ScanCountDefault,
    maxPipelineSize = MaxPipelineSizeDefault,
    rddWriteIteratorGroupingSize = RddWriteIteratorGroupingSizeDefault
  )

  /** Builds a [[ReadWriteConfig]] from the Spark configuration, falling back to the defaults above. */
  def fromSparkConf(conf: SparkConf): ReadWriteConfig =
    ReadWriteConfig(
      scanCount = conf.getInt(ScanCountConfKey, ScanCountDefault),
      maxPipelineSize = conf.getInt(MaxPipelineSizeConfKey, MaxPipelineSizeDefault),
      rddWriteIteratorGroupingSize = conf.getInt(RddWriteIteratorGroupingSizeKey, RddWriteIteratorGroupingSizeDefault)
    )
}
object RedisConfig {
  /**
   * Creates a [[RedisConfig]] from the Spark configuration.
   */
  def fromSparkConf(conf: SparkConf): RedisConfig =
    new RedisConfig(new RedisEndpoint(conf))

  /**
   * Creates a [[RedisConfig]] from the Spark configuration combined with
   * source-specific parameters.
   */
  def fromSparkConfAndParameters(conf: SparkConf, parameters: Map[String, String]): RedisConfig =
    new RedisConfig(new RedisEndpoint(conf, parameters))
}
/**
 * RedisConfig holds the state of the cluster nodes, and uses consistent hashing to map
 * keys to nodes
 */
class RedisConfig(val initialHost: RedisEndpoint) extends Serializable {

  val initialAddr: String = initialHost.host

  // Discover the topology once and derive the master list from it; the
  // original code called getNodes twice, doubling the connections made to
  // the server during construction.
  val nodes: Array[RedisNode] = getNodes(initialHost)
  val hosts: Array[RedisNode] = nodes.filter(_.idx == 0)

  /**
   * @return initialHost's auth
   */
  def getAuth: String = {
    initialHost.auth
  }

  /**
   * @return selected db number
   */
  def getDB: Int = {
    initialHost.dbNum
  }

  /**
   * @return a uniformly random master node
   */
  def getRandomNode(): RedisNode = {
    // Random.nextInt(bound) always yields an index in [0, hosts.length).
    // The previous nextInt().abs % length could produce a negative index,
    // because Int.MinValue.abs is still negative.
    hosts(scala.util.Random.nextInt(hosts.length))
  }

  /**
   * @param sPos start slot number
   * @param ePos end slot number
   * @return a list of master RedisNode whose slots union [sPos, ePos] is not null
   */
  def getNodesBySlots(sPos: Int, ePos: Int): Array[RedisNode] = {
    /* This function judges if [sPos1, ePos1] union [sPos2, ePos2] is not null */
    def inter(sPos1: Int, ePos1: Int, sPos2: Int, ePos2: Int) =
      if (sPos1 <= sPos2) ePos1 >= sPos2 else ePos2 >= sPos1

    nodes.filter(node => inter(sPos, ePos, node.startSlot, node.endSlot)).
      filter(_.idx == 0) //master only now
  }

  /**
   * *IMPORTANT* Please remember to close after using
   *
   * @param key the key whose slot determines the target node
   * @return jedis that is a connection for a given key
   */
  def connectionForKey(key: String): Jedis = {
    getHost(key).connect()
  }

  /**
   * *IMPORTANT* Please remember to close after using
   *
   * @param key the key whose slot determines the target node
   * @return jedis that is a connection for a given key
   */
  def connectionForKey(key: Array[Byte]): Jedis = {
    getHost(key).connect()
  }

  /**
   * @param initialHost any redis endpoint of a cluster or a single server
   * @return true if the target server is in cluster mode
   */
  private def clusterEnabled(initialHost: RedisEndpoint): Boolean = {
    val conn = initialHost.connect()
    try {
      val info = conn.info.split("\n")
      val version = info.filter(_.contains("redis_version:"))(0)
      val clusterEnable = info.filter(_.contains("cluster_enabled:"))
      // "redis_version:" is 14 characters; the major version number follows it.
      val mainVersion = version.substring(14, version.indexOf(".")).toInt
      mainVersion > 2 && clusterEnable.length > 0 && clusterEnable(0).contains("1")
    } finally {
      // Close the connection even when INFO parsing fails.
      conn.close()
    }
  }

  /**
   * @param key
   * @return host whose slots should involve key
   */
  def getHost(key: String): RedisNode = {
    getHostBySlot(JedisClusterCRC16.getSlot(key))
  }

  /**
   * @param key
   * @return host whose slots should involve key
   */
  def getHost(key: Array[Byte]): RedisNode = {
    getHostBySlot(JedisClusterCRC16.getSlot(key))
  }

  private def getHostBySlot(slot: Int): RedisNode = {
    // find + explicit error instead of filter(...)(0), which would throw an
    // uninformative ArrayIndexOutOfBoundsException when no node covers the slot.
    hosts.find(host => host.startSlot <= slot && host.endSlot >= slot)
      .getOrElse(throw new IllegalStateException(s"No host serves slot $slot"))
  }

  /**
   * @param initialHost any redis endpoint of a single server
   * @return list of nodes
   */
  private def getNonClusterNodes(initialHost: RedisEndpoint): Array[RedisNode] = {
    val master = (initialHost.host, initialHost.port)
    val conn = initialHost.connect()
    // Close the connection even when the INFO call fails.
    val replinfo =
      try conn.info("Replication").split("\n")
      finally conn.close()

    // If this node is a slave, we need to extract the slaves from its master
    if (replinfo.exists(_.contains("role:slave"))) {
      val host = replinfo.filter(_.contains("master_host:"))(0).trim.substring(12)
      val port = replinfo.filter(_.contains("master_port:"))(0).trim.substring(12).toInt

      // simply re-enter this function with the master host/port
      getNonClusterNodes(initialHost = RedisEndpoint(
        host = host,
        port = port,
        user = initialHost.user,
        auth = initialHost.auth,
        dbNum = initialHost.dbNum,
        // propagate the timeout too; it was previously dropped on this path
        timeout = initialHost.timeout,
        ssl = initialHost.ssl
      ))
    } else {
      // this is a master - take its slaves
      val slaves = replinfo.filter(x => x.contains("slave") && x.contains("online")).map(rl => {
        val content = rl.substring(rl.indexOf(':') + 1).split(",")
        val ip = content(0)
        val port = content(1)
        (ip.substring(ip.indexOf('=') + 1), port.substring(port.indexOf('=') + 1).toInt)
      })

      val nodes = master +: slaves
      val range = nodes.length
      (0 until range).map(i => {
        val endpoint = RedisEndpoint(
          host = nodes(i)._1,
          port = nodes(i)._2,
          user = initialHost.user,
          auth = initialHost.auth,
          dbNum = initialHost.dbNum,
          timeout = initialHost.timeout,
          ssl = initialHost.ssl)
        RedisNode(endpoint, 0, 16383, i, range)
      }).toArray
    }
  }

  /**
   * @param initialHost any redis endpoint of a cluster server
   * @return list of nodes
   */
  private def getClusterNodes(initialHost: RedisEndpoint): Array[RedisNode] = {
    val conn = initialHost.connect()
    try {
      conn.clusterShards().asScala.flatMap {
        shardInfoObj: ClusterShardInfo => {
          val slotInfo = shardInfoObj.getSlots

          // todo: Can we have more than 1 node per ClusterShard?
          val nodeInfo = shardInfoObj.getNodes.get(0)

          /*
           * We will get all the nodes with the slots range [sPos, ePos],
           * and create RedisNode for each nodes, the total field of all
           * RedisNode are the number of the nodes whose slots range is
           * as above, and the idx field is just an index for each node
           * which will be used for adding support for slaves and so on.
           * And the idx of a master is always 0, we rely on this fact to
           * filter master.
           */
          (0 until slotInfo.size).map(i => {
            val host = SafeEncoder.encode(nodeInfo.getIp.getBytes(Charset.forName("UTF8")))
            val port = nodeInfo.getPort.toInt
            val slotStart = slotInfo.get(i).get(0).toInt
            val slotEnd = slotInfo.get(i).get(1).toInt
            val endpoint = RedisEndpoint(
              host = host,
              port = port,
              user = initialHost.user,
              auth = initialHost.auth,
              dbNum = initialHost.dbNum,
              timeout = initialHost.timeout,
              ssl = initialHost.ssl)
            val role = nodeInfo.getRole
            val idx = if (role == "master") 0 else i
            RedisNode(endpoint, slotStart, slotEnd, idx, slotInfo.size)
          })
        }
      }.toArray
    } finally {
      // Close the connection even when CLUSTER SHARDS fails.
      conn.close()
    }
  }

  /**
   * @param initialHost any redis endpoint of a cluster or a single server
   * @return list of nodes
   */
  def getNodes(initialHost: RedisEndpoint): Array[RedisNode] = {
    if (clusterEnabled(initialHost)) {
      getClusterNodes(initialHost)
    } else {
      getNonClusterNodes(initialHost)
    }
  }
}