@@ -390,7 +390,7 @@ impl QueryRouter {
390
390
391
391
// Likely a read-only query
392
392
Query ( query) => {
393
- match & self . pool_settings . automatic_sharding_key {
393
+ match & self . pool_settings . automatic_sharding_keys {
394
394
Some ( _) => {
395
395
// TODO: if we have multiple queries in the same message,
396
396
// we can either split them and execute them individually
@@ -571,17 +571,23 @@ impl QueryRouter {
571
571
let mut result = Vec :: new ( ) ;
572
572
let mut found = false ;
573
573
574
- let sharding_key = self
574
+ let sharding_keys = self
575
575
. pool_settings
576
- . automatic_sharding_key
576
+ . automatic_sharding_keys
577
577
. as_ref ( )
578
578
. unwrap ( )
579
- . split ( "." )
580
- . map ( |ident| Ident :: new ( ident) )
581
- . collect :: < Vec < Ident > > ( ) ;
582
-
583
- // Sharding key must be always fully qualified
584
- assert_eq ! ( sharding_key. len( ) , 2 ) ;
579
+ . iter ( )
580
+ . map ( |x| {
581
+ x. split ( "." )
582
+ . map ( |ident| Ident :: new ( ident) )
583
+ . collect :: < Vec < Ident > > ( )
584
+ } )
585
+ . collect :: < Vec < Vec < Ident > > > ( ) ;
586
+
587
+ for sharding_key in sharding_keys. iter ( ) {
588
+ // Sharding key must be always fully qualified
589
+ assert_eq ! ( sharding_key. len( ) , 2 ) ;
590
+ }
585
591
586
592
// This parses `sharding_key = 5`. But it's technically
587
593
// legal to write `5 = sharding_key`. I don't judge the people
@@ -593,7 +599,10 @@ impl QueryRouter {
593
599
Expr :: Identifier ( ident) => {
594
600
// Only if we're dealing with only one table
595
601
// and there is no ambiguity
596
- if & ident. value == & sharding_key[ 1 ] . value {
602
+ if let Some ( sharding_key) = sharding_keys
603
+ . iter ( )
604
+ . find ( |key| & ident. value == & key[ 1 ] . value )
605
+ {
597
606
// Sharding key is unique enough, don't worry about
598
607
// table names.
599
608
if & sharding_key[ 0 ] . value == "*" {
@@ -624,8 +633,13 @@ impl QueryRouter {
624
633
// The key is fully qualified in the query,
625
634
// it will exist or Postgres will throw an error.
626
635
if idents. len ( ) == 2 {
627
- found = & sharding_key[ 0 ] . value == & idents[ 0 ] . value
628
- && & sharding_key[ 1 ] . value == & idents[ 1 ] . value ;
636
+ found = sharding_keys
637
+ . iter ( )
638
+ . find ( |key| {
639
+ & key[ 0 ] . value == & idents[ 0 ] . value
640
+ && & key[ 1 ] . value == & idents[ 1 ] . value
641
+ } )
642
+ . is_some ( ) ;
629
643
}
630
644
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
631
645
}
@@ -1166,7 +1180,7 @@ mod test {
1166
1180
query_parser_enabled : true ,
1167
1181
primary_reads_enabled : false ,
1168
1182
sharding_function : ShardingFunction :: PgBigintHash ,
1169
- automatic_sharding_key : Some ( String :: from ( "test.id" ) ) ,
1183
+ automatic_sharding_keys : Some ( vec ! [ String :: from( "test.id" ) ] ) ,
1170
1184
healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
1171
1185
healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
1172
1186
ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -1241,7 +1255,7 @@ mod test {
1241
1255
query_parser_enabled : true ,
1242
1256
primary_reads_enabled : false ,
1243
1257
sharding_function : ShardingFunction :: PgBigintHash ,
1244
- automatic_sharding_key : None ,
1258
+ automatic_sharding_keys : None ,
1245
1259
healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
1246
1260
healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
1247
1261
ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -1282,14 +1296,21 @@ mod test {
1282
1296
QueryRouter :: setup ( ) ;
1283
1297
1284
1298
let mut qr = QueryRouter :: new ( ) ;
1285
- qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1299
+ qr. pool_settings . automatic_sharding_keys =
1300
+ Some ( vec ! [ "data.id" . to_string( ) , "derived.data_id" . to_string( ) ] ) ;
1286
1301
qr. pool_settings . shards = 3 ;
1287
1302
1288
1303
assert ! ( qr
1289
1304
. infer( & QueryRouter :: parse( & simple_query( "SELECT * FROM data WHERE id = 5" ) ) . unwrap( ) )
1290
1305
. is_ok( ) ) ;
1291
1306
assert_eq ! ( qr. shard( ) , 2 ) ;
1292
1307
1308
+ assert ! ( qr. infer( & QueryRouter :: parse( & simple_query( "SELECT * FROM derived WHERE data_id = 5" ) ) . unwrap( ) ) . is_ok( ) ) ;
1309
+ assert_eq ! ( qr. shard( ) , 2 ) ;
1310
+
1311
+ assert ! ( qr. infer( & QueryRouter :: parse( & simple_query( "SELECT * FROM derived WHERE data_id = 5" ) ) . unwrap( ) ) . is_ok( ) ) ;
1312
+ assert_eq ! ( qr. shard( ) , 2 ) ;
1313
+
1293
1314
assert ! ( qr
1294
1315
. infer(
1295
1316
& QueryRouter :: parse( & simple_query(
@@ -1300,6 +1321,16 @@ mod test {
1300
1321
. is_ok( ) ) ;
1301
1322
assert_eq ! ( qr. shard( ) , 0 ) ;
1302
1323
1324
+ assert ! ( qr. infer( & QueryRouter :: parse( & simple_query(
1325
+ "SELECT one, two, three FROM public.derived WHERE data_id = 6"
1326
+ ) ) . unwrap( ) ) . is_ok( ) ) ;
1327
+ assert_eq ! ( qr. shard( ) , 0 ) ;
1328
+
1329
+ assert ! ( qr. infer( & QueryRouter :: parse( & simple_query(
1330
+ "SELECT one, two, three FROM public.derived WHERE data_id = 6"
1331
+ ) ) . unwrap( ) ) . is_ok( ) ) ;
1332
+ assert_eq ! ( qr. shard( ) , 0 ) ;
1333
+
1303
1334
assert ! ( qr
1304
1335
. infer(
1305
1336
& QueryRouter :: parse( & simple_query(
@@ -1346,7 +1377,8 @@ mod test {
1346
1377
assert_eq ! ( qr. shard( ) , 2 ) ;
1347
1378
1348
1379
// Super unique sharding key
1349
- qr. pool_settings . automatic_sharding_key = Some ( "*.unique_enough_column_name" . to_string ( ) ) ;
1380
+ qr. pool_settings . automatic_sharding_keys =
1381
+ Some ( vec ! [ "*.unique_enough_column_name" . to_string( ) ] ) ;
1350
1382
assert ! ( qr
1351
1383
. infer(
1352
1384
& QueryRouter :: parse( & simple_query(
@@ -1383,7 +1415,7 @@ mod test {
1383
1415
bind. put ( payload) ;
1384
1416
1385
1417
let mut qr = QueryRouter :: new ( ) ;
1386
- qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1418
+ qr. pool_settings . automatic_sharding_keys = Some ( vec ! [ "data.id" . to_string( ) ] ) ;
1387
1419
qr. pool_settings . shards = 3 ;
1388
1420
1389
1421
assert ! ( qr
0 commit comments