Skip to content

Commit d5b8579

Browse files
committed
allow for multiple sharding keys
1 parent 4b78af9 commit d5b8579

File tree

3 files changed

+70
-38
lines changed

3 files changed

+70
-38
lines changed

src/config.rs

+18-18
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ pub struct Pool {
522522
pub sharding_function: ShardingFunction,
523523

524524
#[serde(default = "Pool::default_automatic_sharding_key")]
525-
pub automatic_sharding_key: Option<String>,
525+
pub automatic_sharding_keys: Option<Vec<String>>,
526526

527527
pub sharding_key_regex: Option<String>,
528528
pub shard_id_regex: Option<String>,
@@ -564,7 +564,7 @@ impl Pool {
564564
LoadBalancingMode::Random
565565
}
566566

567-
pub fn default_automatic_sharding_key() -> Option<String> {
567+
pub fn default_automatic_sharding_key() -> Option<Vec<String>> {
568568
None
569569
}
570570

@@ -620,23 +620,23 @@ impl Pool {
620620
}
621621
}
622622

623-
self.automatic_sharding_key = match &self.automatic_sharding_key {
624-
Some(key) => {
625-
// No quotes in the key so we don't have to compare quoted
626-
// to unquoted idents.
627-
let key = key.replace("\"", "");
628-
629-
if key.split(".").count() != 2 {
630-
error!(
631-
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
632-
key, key
633-
);
634-
return Err(Error::BadConfig);
623+
match &mut self.automatic_sharding_keys {
624+
Some(keys) => {
625+
for key in keys {
626+
// No quotes in the key so we don't have to compare quoted
627+
// to unquoted idents.
628+
let key = key.replace("\"", "");
629+
630+
if key.split(".").count() != 2 {
631+
error!(
632+
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
633+
key, key
634+
);
635+
return Err(Error::BadConfig);
636+
}
635637
}
636-
637-
Some(key)
638638
}
639-
None => None,
639+
None => (),
640640
};
641641

642642
for (_, user) in &self.users {
@@ -658,7 +658,7 @@ impl Default for Pool {
658658
query_parser_enabled: false,
659659
primary_reads_enabled: false,
660660
sharding_function: ShardingFunction::PgBigintHash,
661-
automatic_sharding_key: None,
661+
automatic_sharding_keys: None,
662662
connect_timeout: None,
663663
idle_timeout: None,
664664
sharding_key_regex: None,

src/pool.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ pub struct PoolSettings {
118118
pub sharding_function: ShardingFunction,
119119

120120
// Sharding key
121-
pub automatic_sharding_key: Option<String>,
121+
pub automatic_sharding_keys: Option<Vec<String>>,
122122

123123
// Health check timeout
124124
pub healthcheck_timeout: u64,
@@ -159,7 +159,7 @@ impl Default for PoolSettings {
159159
query_parser_enabled: false,
160160
primary_reads_enabled: true,
161161
sharding_function: ShardingFunction::PgBigintHash,
162-
automatic_sharding_key: None,
162+
automatic_sharding_keys: None,
163163
healthcheck_delay: General::default_healthcheck_delay(),
164164
healthcheck_timeout: General::default_healthcheck_timeout(),
165165
ban_time: General::default_ban_time(),
@@ -458,7 +458,7 @@ impl ConnectionPool {
458458
query_parser_enabled: pool_config.query_parser_enabled,
459459
primary_reads_enabled: pool_config.primary_reads_enabled,
460460
sharding_function: pool_config.sharding_function,
461-
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
461+
automatic_sharding_keys: pool_config.automatic_sharding_keys.clone(),
462462
healthcheck_delay: config.general.healthcheck_delay,
463463
healthcheck_timeout: config.general.healthcheck_timeout,
464464
ban_time: config.general.ban_time,

src/query_router.rs

+49-17
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ impl QueryRouter {
390390

391391
// Likely a read-only query
392392
Query(query) => {
393-
match &self.pool_settings.automatic_sharding_key {
393+
match &self.pool_settings.automatic_sharding_keys {
394394
Some(_) => {
395395
// TODO: if we have multiple queries in the same message,
396396
// we can either split them and execute them individually
@@ -571,17 +571,23 @@ impl QueryRouter {
571571
let mut result = Vec::new();
572572
let mut found = false;
573573

574-
let sharding_key = self
574+
let sharding_keys = self
575575
.pool_settings
576-
.automatic_sharding_key
576+
.automatic_sharding_keys
577577
.as_ref()
578578
.unwrap()
579-
.split(".")
580-
.map(|ident| Ident::new(ident))
581-
.collect::<Vec<Ident>>();
582-
583-
// Sharding key must be always fully qualified
584-
assert_eq!(sharding_key.len(), 2);
579+
.iter()
580+
.map(|x| {
581+
x.split(".")
582+
.map(|ident| Ident::new(ident))
583+
.collect::<Vec<Ident>>()
584+
})
585+
.collect::<Vec<Vec<Ident>>>();
586+
587+
for sharding_key in sharding_keys.iter() {
588+
// Sharding key must be always fully qualified
589+
assert_eq!(sharding_key.len(), 2);
590+
}
585591

586592
// This parses `sharding_key = 5`. But it's technically
587593
// legal to write `5 = sharding_key`. I don't judge the people
@@ -593,7 +599,10 @@ impl QueryRouter {
593599
Expr::Identifier(ident) => {
594600
// Only if we're dealing with only one table
595601
// and there is no ambiguity
596-
if &ident.value == &sharding_key[1].value {
602+
if let Some(sharding_key) = sharding_keys
603+
.iter()
604+
.find(|key| &ident.value == &key[1].value)
605+
{
597606
// Sharding key is unique enough, don't worry about
598607
// table names.
599608
if &sharding_key[0].value == "*" {
@@ -624,8 +633,13 @@ impl QueryRouter {
624633
// The key is fully qualified in the query,
625634
// it will exist or Postgres will throw an error.
626635
if idents.len() == 2 {
627-
found = &sharding_key[0].value == &idents[0].value
628-
&& &sharding_key[1].value == &idents[1].value;
636+
found = sharding_keys
637+
.iter()
638+
.find(|key| {
639+
&key[0].value == &idents[0].value
640+
&& &key[1].value == &idents[1].value
641+
})
642+
.is_some();
629643
}
630644
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
631645
}
@@ -1166,7 +1180,7 @@ mod test {
11661180
query_parser_enabled: true,
11671181
primary_reads_enabled: false,
11681182
sharding_function: ShardingFunction::PgBigintHash,
1169-
automatic_sharding_key: Some(String::from("test.id")),
1183+
automatic_sharding_keys: Some(vec![String::from("test.id")]),
11701184
healthcheck_delay: PoolSettings::default().healthcheck_delay,
11711185
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
11721186
ban_time: PoolSettings::default().ban_time,
@@ -1241,7 +1255,7 @@ mod test {
12411255
query_parser_enabled: true,
12421256
primary_reads_enabled: false,
12431257
sharding_function: ShardingFunction::PgBigintHash,
1244-
automatic_sharding_key: None,
1258+
automatic_sharding_keys: None,
12451259
healthcheck_delay: PoolSettings::default().healthcheck_delay,
12461260
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
12471261
ban_time: PoolSettings::default().ban_time,
@@ -1282,14 +1296,21 @@ mod test {
12821296
QueryRouter::setup();
12831297

12841298
let mut qr = QueryRouter::new();
1285-
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
1299+
qr.pool_settings.automatic_sharding_keys =
1300+
Some(vec!["data.id".to_string(), "derived.data_id".to_string()]);
12861301
qr.pool_settings.shards = 3;
12871302

12881303
assert!(qr
12891304
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
12901305
.is_ok());
12911306
assert_eq!(qr.shard(), 2);
12921307

1308+
assert!(qr.infer(&QueryRouter::parse(&simple_query("SELECT * FROM derived WHERE data_id = 5")).unwrap()).is_ok());
1309+
assert_eq!(qr.shard(), 2);
1310+
1311+
assert!(qr.infer(&QueryRouter::parse(&simple_query("SELECT * FROM derived WHERE data_id = 5")).unwrap()).is_ok());
1312+
assert_eq!(qr.shard(), 2);
1313+
12931314
assert!(qr
12941315
.infer(
12951316
&QueryRouter::parse(&simple_query(
@@ -1300,6 +1321,16 @@ mod test {
13001321
.is_ok());
13011322
assert_eq!(qr.shard(), 0);
13021323

1324+
assert!(qr.infer(&QueryRouter::parse(&simple_query(
1325+
"SELECT one, two, three FROM public.derived WHERE data_id = 6"
1326+
)).unwrap()).is_ok());
1327+
assert_eq!(qr.shard(), 0);
1328+
1329+
assert!(qr.infer(&QueryRouter::parse(&simple_query(
1330+
"SELECT one, two, three FROM public.derived WHERE data_id = 6"
1331+
)).unwrap()).is_ok());
1332+
assert_eq!(qr.shard(), 0);
1333+
13031334
assert!(qr
13041335
.infer(
13051336
&QueryRouter::parse(&simple_query(
@@ -1346,7 +1377,8 @@ mod test {
13461377
assert_eq!(qr.shard(), 2);
13471378

13481379
// Super unique sharding key
1349-
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
1380+
qr.pool_settings.automatic_sharding_keys =
1381+
Some(vec!["*.unique_enough_column_name".to_string()]);
13501382
assert!(qr
13511383
.infer(
13521384
&QueryRouter::parse(&simple_query(
@@ -1383,7 +1415,7 @@ mod test {
13831415
bind.put(payload);
13841416

13851417
let mut qr = QueryRouter::new();
1386-
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
1418+
qr.pool_settings.automatic_sharding_keys = Some(vec!["data.id".to_string()]);
13871419
qr.pool_settings.shards = 3;
13881420

13891421
assert!(qr

0 commit comments

Comments
 (0)