@@ -390,7 +390,7 @@ impl QueryRouter {
390390
391391 // Likely a read-only query
392392 Query ( query) => {
393- match & self . pool_settings . automatic_sharding_key {
393+ match & self . pool_settings . automatic_sharding_keys {
394394 Some ( _) => {
395395 // TODO: if we have multiple queries in the same message,
396396 // we can either split them and execute them individually
@@ -571,17 +571,23 @@ impl QueryRouter {
571571 let mut result = Vec :: new ( ) ;
572572 let mut found = false ;
573573
574- let sharding_key = self
574+ let sharding_keys = self
575575 . pool_settings
576- . automatic_sharding_key
576+ . automatic_sharding_keys
577577 . as_ref ( )
578578 . unwrap ( )
579- . split ( "." )
580- . map ( |ident| Ident :: new ( ident) )
581- . collect :: < Vec < Ident > > ( ) ;
582-
583- // Sharding key must be always fully qualified
584- assert_eq ! ( sharding_key. len( ) , 2 ) ;
579+ . iter ( )
580+ . map ( |x| {
581+ x. split ( "." )
582+ . map ( |ident| Ident :: new ( ident) )
583+ . collect :: < Vec < Ident > > ( )
584+ } )
585+ . collect :: < Vec < Vec < Ident > > > ( ) ;
586+
587+ for sharding_key in sharding_keys. iter ( ) {
588+ // Sharding key must be always fully qualified
589+ assert_eq ! ( sharding_key. len( ) , 2 ) ;
590+ }
585591
586592 // This parses `sharding_key = 5`. But it's technically
587593 // legal to write `5 = sharding_key`. I don't judge the people
@@ -593,7 +599,10 @@ impl QueryRouter {
593599 Expr :: Identifier ( ident) => {
594600 // Only if we're dealing with only one table
595601 // and there is no ambiguity
596- if & ident. value == & sharding_key[ 1 ] . value {
602+ if let Some ( sharding_key) = sharding_keys
603+ . iter ( )
604+ . find ( |key| & ident. value == & key[ 1 ] . value )
605+ {
597606 // Sharding key is unique enough, don't worry about
598607 // table names.
599608 if & sharding_key[ 0 ] . value == "*" {
@@ -624,8 +633,13 @@ impl QueryRouter {
624633 // The key is fully qualified in the query,
625634 // it will exist or Postgres will throw an error.
626635 if idents. len ( ) == 2 {
627- found = & sharding_key[ 0 ] . value == & idents[ 0 ] . value
628- && & sharding_key[ 1 ] . value == & idents[ 1 ] . value ;
636+ found = sharding_keys
637+ . iter ( )
638+ . find ( |key| {
639+ & key[ 0 ] . value == & idents[ 0 ] . value
640+ && & key[ 1 ] . value == & idents[ 1 ] . value
641+ } )
642+ . is_some ( ) ;
629643 }
630644 // TODO: key can have schema as well, e.g. public.data.id (len == 3)
631645 }
@@ -1166,7 +1180,7 @@ mod test {
11661180 query_parser_enabled : true ,
11671181 primary_reads_enabled : false ,
11681182 sharding_function : ShardingFunction :: PgBigintHash ,
1169- automatic_sharding_key : Some ( String :: from ( "test.id" ) ) ,
1183+ automatic_sharding_keys : Some ( vec ! [ String :: from( "test.id" ) ] ) ,
11701184 healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
11711185 healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
11721186 ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -1241,7 +1255,7 @@ mod test {
12411255 query_parser_enabled : true ,
12421256 primary_reads_enabled : false ,
12431257 sharding_function : ShardingFunction :: PgBigintHash ,
1244- automatic_sharding_key : None ,
1258+ automatic_sharding_keys : None ,
12451259 healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
12461260 healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
12471261 ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -1282,14 +1296,21 @@ mod test {
12821296 QueryRouter :: setup ( ) ;
12831297
12841298 let mut qr = QueryRouter :: new ( ) ;
1285- qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1299+ qr. pool_settings . automatic_sharding_keys =
1300+ Some ( vec ! [ "data.id" . to_string( ) , "derived.data_id" . to_string( ) ] ) ;
12861301 qr. pool_settings . shards = 3 ;
12871302
12881303 assert ! ( qr
12891304 . infer( & QueryRouter :: parse( & simple_query( "SELECT * FROM data WHERE id = 5" ) ) . unwrap( ) )
12901305 . is_ok( ) ) ;
12911306 assert_eq ! ( qr. shard( ) , 2 ) ;
12921307
1308+ assert ! ( qr. infer( & QueryRouter :: parse( & simple_query( "SELECT * FROM derived WHERE data_id = 5" ) ) . unwrap( ) ) . is_ok( ) ) ;
1309+ assert_eq ! ( qr. shard( ) , 2 ) ;
1310+
1311+ assert ! ( qr. infer( & QueryRouter :: parse( & simple_query( "SELECT * FROM derived WHERE data_id = 5" ) ) . unwrap( ) ) . is_ok( ) ) ;
1312+ assert_eq ! ( qr. shard( ) , 2 ) ;
1313+
12931314 assert ! ( qr
12941315 . infer(
12951316 & QueryRouter :: parse( & simple_query(
@@ -1300,6 +1321,16 @@ mod test {
13001321 . is_ok( ) ) ;
13011322 assert_eq ! ( qr. shard( ) , 0 ) ;
13021323
1324+ assert ! ( qr. infer( & QueryRouter :: parse( & simple_query(
1325+ "SELECT one, two, three FROM public.derived WHERE data_id = 6"
1326+ ) ) . unwrap( ) ) . is_ok( ) ) ;
1327+ assert_eq ! ( qr. shard( ) , 0 ) ;
1328+
1329+ assert ! ( qr. infer( & QueryRouter :: parse( & simple_query(
1330+ "SELECT one, two, three FROM public.derived WHERE data_id = 6"
1331+ ) ) . unwrap( ) ) . is_ok( ) ) ;
1332+ assert_eq ! ( qr. shard( ) , 0 ) ;
1333+
13031334 assert ! ( qr
13041335 . infer(
13051336 & QueryRouter :: parse( & simple_query(
@@ -1346,7 +1377,8 @@ mod test {
13461377 assert_eq ! ( qr. shard( ) , 2 ) ;
13471378
13481379 // Super unique sharding key
1349- qr. pool_settings . automatic_sharding_key = Some ( "*.unique_enough_column_name" . to_string ( ) ) ;
1380+ qr. pool_settings . automatic_sharding_keys =
1381+ Some ( vec ! [ "*.unique_enough_column_name" . to_string( ) ] ) ;
13501382 assert ! ( qr
13511383 . infer(
13521384 & QueryRouter :: parse( & simple_query(
@@ -1383,7 +1415,7 @@ mod test {
13831415 bind. put ( payload) ;
13841416
13851417 let mut qr = QueryRouter :: new ( ) ;
1386- qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1418+ qr. pool_settings . automatic_sharding_keys = Some ( vec ! [ "data.id" . to_string( ) ] ) ;
13871419 qr. pool_settings . shards = 3 ;
13881420
13891421 assert ! ( qr
0 commit comments