@@ -384,7 +384,7 @@ impl QueryRouter {
384
384
385
385
// Likely a read-only query
386
386
Query ( query) => {
387
- match & self . pool_settings . automatic_sharding_key {
387
+ match & self . pool_settings . automatic_sharding_keys {
388
388
Some ( _) => {
389
389
// TODO: if we have multiple queries in the same message,
390
390
// we can either split them and execute them individually
@@ -565,17 +565,23 @@ impl QueryRouter {
565
565
let mut result = Vec :: new ( ) ;
566
566
let mut found = false ;
567
567
568
- let sharding_key = self
568
+ let sharding_keys = self
569
569
. pool_settings
570
- . automatic_sharding_key
570
+ . automatic_sharding_keys
571
571
. as_ref ( )
572
572
. unwrap ( )
573
- . split ( "." )
574
- . map ( |ident| Ident :: new ( ident) )
575
- . collect :: < Vec < Ident > > ( ) ;
576
-
577
- // Sharding key must be always fully qualified
578
- assert_eq ! ( sharding_key. len( ) , 2 ) ;
573
+ . iter ( )
574
+ . map ( |x| {
575
+ x. split ( "." )
576
+ . map ( |ident| Ident :: new ( ident) )
577
+ . collect :: < Vec < Ident > > ( )
578
+ } )
579
+ . collect :: < Vec < Vec < Ident > > > ( ) ;
580
+
581
+ for sharding_key in sharding_keys. iter ( ) {
582
+ // Sharding key must be always fully qualified
583
+ assert_eq ! ( sharding_key. len( ) , 2 ) ;
584
+ }
579
585
580
586
// This parses `sharding_key = 5`. But it's technically
581
587
// legal to write `5 = sharding_key`. I don't judge the people
@@ -587,7 +593,10 @@ impl QueryRouter {
587
593
Expr :: Identifier ( ident) => {
588
594
// Only if we're dealing with only one table
589
595
// and there is no ambiguity
590
- if & ident. value == & sharding_key[ 1 ] . value {
596
+ if let Some ( sharding_key) = sharding_keys
597
+ . iter ( )
598
+ . find ( |key| & ident. value == & key[ 1 ] . value )
599
+ {
591
600
// Sharding key is unique enough, don't worry about
592
601
// table names.
593
602
if & sharding_key[ 0 ] . value == "*" {
@@ -618,8 +627,13 @@ impl QueryRouter {
618
627
// The key is fully qualified in the query,
619
628
// it will exist or Postgres will throw an error.
620
629
if idents. len ( ) == 2 {
621
- found = & sharding_key[ 0 ] . value == & idents[ 0 ] . value
622
- && & sharding_key[ 1 ] . value == & idents[ 1 ] . value ;
630
+ found = sharding_keys
631
+ . iter ( )
632
+ . find ( |key| {
633
+ & key[ 0 ] . value == & idents[ 0 ] . value
634
+ && & key[ 1 ] . value == & idents[ 1 ] . value
635
+ } )
636
+ . is_some ( ) ;
623
637
}
624
638
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
625
639
}
@@ -1103,7 +1117,7 @@ mod test {
1103
1117
query_parser_enabled : true ,
1104
1118
primary_reads_enabled : false ,
1105
1119
sharding_function : ShardingFunction :: PgBigintHash ,
1106
- automatic_sharding_key : Some ( String :: from ( "test.id" ) ) ,
1120
+ automatic_sharding_keys : Some ( vec ! [ String :: from( "test.id" ) ] ) ,
1107
1121
healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
1108
1122
healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
1109
1123
ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -1167,7 +1181,7 @@ mod test {
1167
1181
query_parser_enabled : true ,
1168
1182
primary_reads_enabled : false ,
1169
1183
sharding_function : ShardingFunction :: PgBigintHash ,
1170
- automatic_sharding_key : None ,
1184
+ automatic_sharding_keys : None ,
1171
1185
healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
1172
1186
healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
1173
1187
ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -1205,17 +1219,26 @@ mod test {
1205
1219
QueryRouter :: setup ( ) ;
1206
1220
1207
1221
let mut qr = QueryRouter :: new ( ) ;
1208
- qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1222
+ qr. pool_settings . automatic_sharding_keys =
1223
+ Some ( vec ! [ "data.id" . to_string( ) , "derived.data_id" . to_string( ) ] ) ;
1209
1224
qr. pool_settings . shards = 3 ;
1210
1225
1211
1226
assert ! ( qr. infer( & simple_query( "SELECT * FROM data WHERE id = 5" ) ) ) ;
1212
1227
assert_eq ! ( qr. shard( ) , 2 ) ;
1213
1228
1229
+ assert ! ( qr. infer( & simple_query( "SELECT * FROM derived WHERE data_id = 5" ) ) ) ;
1230
+ assert_eq ! ( qr. shard( ) , 2 ) ;
1231
+
1214
1232
assert ! ( qr. infer( & simple_query(
1215
1233
"SELECT one, two, three FROM public.data WHERE id = 6"
1216
1234
) ) ) ;
1217
1235
assert_eq ! ( qr. shard( ) , 0 ) ;
1218
1236
1237
+ assert ! ( qr. infer( & simple_query(
1238
+ "SELECT one, two, three FROM public.derived WHERE data_id = 6"
1239
+ ) ) ) ;
1240
+ assert_eq ! ( qr. shard( ) , 0 ) ;
1241
+
1219
1242
assert ! ( qr. infer( & simple_query(
1220
1243
"SELECT * FROM data
1221
1244
INNER JOIN t2 ON data.id = 5
@@ -1242,7 +1265,8 @@ mod test {
1242
1265
assert_eq ! ( qr. shard( ) , 2 ) ;
1243
1266
1244
1267
// Super unique sharding key
1245
- qr. pool_settings . automatic_sharding_key = Some ( "*.unique_enough_column_name" . to_string ( ) ) ;
1268
+ qr. pool_settings . automatic_sharding_keys =
1269
+ Some ( vec ! [ "*.unique_enough_column_name" . to_string( ) ] ) ;
1246
1270
assert ! ( qr. infer( & simple_query(
1247
1271
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
1248
1272
) ) ) ;
@@ -1269,7 +1293,7 @@ mod test {
1269
1293
bind. put ( payload) ;
1270
1294
1271
1295
let mut qr = QueryRouter :: new ( ) ;
1272
- qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1296
+ qr. pool_settings . automatic_sharding_keys = Some ( vec ! [ "data.id" . to_string( ) ] ) ;
1273
1297
qr. pool_settings . shards = 3 ;
1274
1298
1275
1299
assert ! ( qr. infer( & simple_query( stmt) ) ) ;
0 commit comments