Skip to content

Commit fffdbd2

Browse files
committed
allow for multiple sharding keys
1 parent 3601130 commit fffdbd2

File tree

3 files changed

+62
-38
lines changed

3 files changed

+62
-38
lines changed

src/config.rs

+18-18
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ pub struct Pool {
457457
pub sharding_function: ShardingFunction,
458458

459459
#[serde(default = "Pool::default_automatic_sharding_key")]
460-
pub automatic_sharding_key: Option<String>,
460+
pub automatic_sharding_keys: Option<Vec<String>>,
461461

462462
pub sharding_key_regex: Option<String>,
463463
pub shard_id_regex: Option<String>,
@@ -495,7 +495,7 @@ impl Pool {
495495
LoadBalancingMode::Random
496496
}
497497

498-
pub fn default_automatic_sharding_key() -> Option<String> {
498+
pub fn default_automatic_sharding_key() -> Option<Vec<String>> {
499499
None
500500
}
501501

@@ -539,23 +539,23 @@ impl Pool {
539539
}
540540
}
541541

542-
self.automatic_sharding_key = match &self.automatic_sharding_key {
543-
Some(key) => {
544-
// No quotes in the key so we don't have to compare quoted
545-
// to unquoted idents.
546-
let key = key.replace("\"", "");
547-
548-
if key.split(".").count() != 2 {
549-
error!(
550-
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
551-
key, key
552-
);
553-
return Err(Error::BadConfig);
542+
match &mut self.automatic_sharding_keys {
543+
Some(keys) => {
544+
for key in keys {
545+
// No quotes in the key so we don't have to compare quoted
546+
// to unquoted idents.
547+
let key = key.replace("\"", "");
548+
549+
if key.split(".").count() != 2 {
550+
error!(
551+
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
552+
key, key
553+
);
554+
return Err(Error::BadConfig);
555+
}
554556
}
555-
556-
Some(key)
557557
}
558-
None => None,
558+
None => (),
559559
};
560560

561561
for (_, user) in &self.users {
@@ -577,7 +577,7 @@ impl Default for Pool {
577577
query_parser_enabled: false,
578578
primary_reads_enabled: false,
579579
sharding_function: ShardingFunction::PgBigintHash,
580-
automatic_sharding_key: None,
580+
automatic_sharding_keys: None,
581581
connect_timeout: None,
582582
idle_timeout: None,
583583
sharding_key_regex: None,

src/pool.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ pub struct PoolSettings {
105105
pub sharding_function: ShardingFunction,
106106

107107
// Sharding key
108-
pub automatic_sharding_key: Option<String>,
108+
pub automatic_sharding_keys: Option<Vec<String>>,
109109

110110
// Health check timeout
111111
pub healthcheck_timeout: u64,
@@ -142,7 +142,7 @@ impl Default for PoolSettings {
142142
query_parser_enabled: false,
143143
primary_reads_enabled: true,
144144
sharding_function: ShardingFunction::PgBigintHash,
145-
automatic_sharding_key: None,
145+
automatic_sharding_keys: None,
146146
healthcheck_delay: General::default_healthcheck_delay(),
147147
healthcheck_timeout: General::default_healthcheck_timeout(),
148148
ban_time: General::default_ban_time(),
@@ -421,7 +421,7 @@ impl ConnectionPool {
421421
query_parser_enabled: pool_config.query_parser_enabled,
422422
primary_reads_enabled: pool_config.primary_reads_enabled,
423423
sharding_function: pool_config.sharding_function,
424-
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
424+
automatic_sharding_keys: pool_config.automatic_sharding_keys.clone(),
425425
healthcheck_delay: config.general.healthcheck_delay,
426426
healthcheck_timeout: config.general.healthcheck_timeout,
427427
ban_time: config.general.ban_time,

src/query_router.rs

+41-17
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ impl QueryRouter {
384384

385385
// Likely a read-only query
386386
Query(query) => {
387-
match &self.pool_settings.automatic_sharding_key {
387+
match &self.pool_settings.automatic_sharding_keys {
388388
Some(_) => {
389389
// TODO: if we have multiple queries in the same message,
390390
// we can either split them and execute them individually
@@ -565,17 +565,23 @@ impl QueryRouter {
565565
let mut result = Vec::new();
566566
let mut found = false;
567567

568-
let sharding_key = self
568+
let sharding_keys = self
569569
.pool_settings
570-
.automatic_sharding_key
570+
.automatic_sharding_keys
571571
.as_ref()
572572
.unwrap()
573-
.split(".")
574-
.map(|ident| Ident::new(ident))
575-
.collect::<Vec<Ident>>();
576-
577-
// Sharding key must be always fully qualified
578-
assert_eq!(sharding_key.len(), 2);
573+
.iter()
574+
.map(|x| {
575+
x.split(".")
576+
.map(|ident| Ident::new(ident))
577+
.collect::<Vec<Ident>>()
578+
})
579+
.collect::<Vec<Vec<Ident>>>();
580+
581+
for sharding_key in sharding_keys.iter() {
582+
// Sharding key must be always fully qualified
583+
assert_eq!(sharding_key.len(), 2);
584+
}
579585

580586
// This parses `sharding_key = 5`. But it's technically
581587
// legal to write `5 = sharding_key`. I don't judge the people
@@ -587,7 +593,10 @@ impl QueryRouter {
587593
Expr::Identifier(ident) => {
588594
// Only if we're dealing with only one table
589595
// and there is no ambiguity
590-
if &ident.value == &sharding_key[1].value {
596+
if let Some(sharding_key) = sharding_keys
597+
.iter()
598+
.find(|key| &ident.value == &key[1].value)
599+
{
591600
// Sharding key is unique enough, don't worry about
592601
// table names.
593602
if &sharding_key[0].value == "*" {
@@ -618,8 +627,13 @@ impl QueryRouter {
618627
// The key is fully qualified in the query,
619628
// it will exist or Postgres will throw an error.
620629
if idents.len() == 2 {
621-
found = &sharding_key[0].value == &idents[0].value
622-
&& &sharding_key[1].value == &idents[1].value;
630+
found = sharding_keys
631+
.iter()
632+
.find(|key| {
633+
&key[0].value == &idents[0].value
634+
&& &key[1].value == &idents[1].value
635+
})
636+
.is_some();
623637
}
624638
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
625639
}
@@ -1103,7 +1117,7 @@ mod test {
11031117
query_parser_enabled: true,
11041118
primary_reads_enabled: false,
11051119
sharding_function: ShardingFunction::PgBigintHash,
1106-
automatic_sharding_key: Some(String::from("test.id")),
1120+
automatic_sharding_keys: Some(vec![String::from("test.id")]),
11071121
healthcheck_delay: PoolSettings::default().healthcheck_delay,
11081122
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
11091123
ban_time: PoolSettings::default().ban_time,
@@ -1167,7 +1181,7 @@ mod test {
11671181
query_parser_enabled: true,
11681182
primary_reads_enabled: false,
11691183
sharding_function: ShardingFunction::PgBigintHash,
1170-
automatic_sharding_key: None,
1184+
automatic_sharding_keys: None,
11711185
healthcheck_delay: PoolSettings::default().healthcheck_delay,
11721186
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
11731187
ban_time: PoolSettings::default().ban_time,
@@ -1205,17 +1219,26 @@ mod test {
12051219
QueryRouter::setup();
12061220

12071221
let mut qr = QueryRouter::new();
1208-
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
1222+
qr.pool_settings.automatic_sharding_keys =
1223+
Some(vec!["data.id".to_string(), "derived.data_id".to_string()]);
12091224
qr.pool_settings.shards = 3;
12101225

12111226
assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5")));
12121227
assert_eq!(qr.shard(), 2);
12131228

1229+
assert!(qr.infer(&simple_query("SELECT * FROM derived WHERE data_id = 5")));
1230+
assert_eq!(qr.shard(), 2);
1231+
12141232
assert!(qr.infer(&simple_query(
12151233
"SELECT one, two, three FROM public.data WHERE id = 6"
12161234
)));
12171235
assert_eq!(qr.shard(), 0);
12181236

1237+
assert!(qr.infer(&simple_query(
1238+
"SELECT one, two, three FROM public.derived WHERE data_id = 6"
1239+
)));
1240+
assert_eq!(qr.shard(), 0);
1241+
12191242
assert!(qr.infer(&simple_query(
12201243
"SELECT * FROM data
12211244
INNER JOIN t2 ON data.id = 5
@@ -1242,7 +1265,8 @@ mod test {
12421265
assert_eq!(qr.shard(), 2);
12431266

12441267
// Super unique sharding key
1245-
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
1268+
qr.pool_settings.automatic_sharding_keys =
1269+
Some(vec!["*.unique_enough_column_name".to_string()]);
12461270
assert!(qr.infer(&simple_query(
12471271
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
12481272
)));
@@ -1269,7 +1293,7 @@ mod test {
12691293
bind.put(payload);
12701294

12711295
let mut qr = QueryRouter::new();
1272-
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
1296+
qr.pool_settings.automatic_sharding_keys = Some(vec!["data.id".to_string()]);
12731297
qr.pool_settings.shards = 3;
12741298

12751299
assert!(qr.infer(&simple_query(stmt)));

0 commit comments

Comments
 (0)