@@ -384,7 +384,7 @@ impl QueryRouter {
384384
385385 // Likely a read-only query
386386 Query ( query) => {
387- match & self . pool_settings . automatic_sharding_key {
387+ match & self . pool_settings . automatic_sharding_keys {
388388 Some ( _) => {
389389 // TODO: if we have multiple queries in the same message,
390390 // we can either split them and execute them individually
@@ -565,17 +565,23 @@ impl QueryRouter {
565565 let mut result = Vec :: new ( ) ;
566566 let mut found = false ;
567567
568- let sharding_key = self
568+ let sharding_keys = self
569569 . pool_settings
570- . automatic_sharding_key
570+ . automatic_sharding_keys
571571 . as_ref ( )
572572 . unwrap ( )
573- . split ( "." )
574- . map ( |ident| Ident :: new ( ident) )
575- . collect :: < Vec < Ident > > ( ) ;
576-
577- // Sharding key must be always fully qualified
578- assert_eq ! ( sharding_key. len( ) , 2 ) ;
573+ . iter ( )
574+ . map ( |x| {
575+ x. split ( "." )
576+ . map ( |ident| Ident :: new ( ident) )
577+ . collect :: < Vec < Ident > > ( )
578+ } )
579+ . collect :: < Vec < Vec < Ident > > > ( ) ;
580+
581+ for sharding_key in sharding_keys. iter ( ) {
582+ // Sharding key must be always fully qualified
583+ assert_eq ! ( sharding_key. len( ) , 2 ) ;
584+ }
579585
580586 // This parses `sharding_key = 5`. But it's technically
581587 // legal to write `5 = sharding_key`. I don't judge the people
@@ -587,7 +593,10 @@ impl QueryRouter {
587593 Expr :: Identifier ( ident) => {
588594 // Only if we're dealing with only one table
589595 // and there is no ambiguity
590- if & ident. value == & sharding_key[ 1 ] . value {
596+ if let Some ( sharding_key) = sharding_keys
597+ . iter ( )
598+ . find ( |key| & ident. value == & key[ 1 ] . value )
599+ {
591600 // Sharding key is unique enough, don't worry about
592601 // table names.
593602 if & sharding_key[ 0 ] . value == "*" {
@@ -618,8 +627,13 @@ impl QueryRouter {
618627 // The key is fully qualified in the query,
619628 // it will exist or Postgres will throw an error.
620629 if idents. len ( ) == 2 {
621- found = & sharding_key[ 0 ] . value == & idents[ 0 ] . value
622- && & sharding_key[ 1 ] . value == & idents[ 1 ] . value ;
630+ found = sharding_keys
631+ . iter ( )
632+ . find ( |key| {
633+ & key[ 0 ] . value == & idents[ 0 ] . value
634+ && & key[ 1 ] . value == & idents[ 1 ] . value
635+ } )
636+ . is_some ( ) ;
623637 }
624638 // TODO: key can have schema as well, e.g. public.data.id (len == 3)
625639 }
@@ -1103,7 +1117,7 @@ mod test {
11031117 query_parser_enabled : true ,
11041118 primary_reads_enabled : false ,
11051119 sharding_function : ShardingFunction :: PgBigintHash ,
1106- automatic_sharding_key : Some ( String :: from ( "test.id" ) ) ,
1120+ automatic_sharding_keys : Some ( vec ! [ String :: from( "test.id" ) ] ) ,
11071121 healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
11081122 healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
11091123 ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -1167,7 +1181,7 @@ mod test {
11671181 query_parser_enabled : true ,
11681182 primary_reads_enabled : false ,
11691183 sharding_function : ShardingFunction :: PgBigintHash ,
1170- automatic_sharding_key : None ,
1184+ automatic_sharding_keys : None ,
11711185 healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
11721186 healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
11731187 ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -1205,17 +1219,26 @@ mod test {
12051219 QueryRouter :: setup ( ) ;
12061220
12071221 let mut qr = QueryRouter :: new ( ) ;
1208- qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1222+ qr. pool_settings . automatic_sharding_keys =
1223+ Some ( vec ! [ "data.id" . to_string( ) , "derived.data_id" . to_string( ) ] ) ;
12091224 qr. pool_settings . shards = 3 ;
12101225
12111226 assert ! ( qr. infer( & simple_query( "SELECT * FROM data WHERE id = 5" ) ) ) ;
12121227 assert_eq ! ( qr. shard( ) , 2 ) ;
12131228
1229+ assert ! ( qr. infer( & simple_query( "SELECT * FROM derived WHERE data_id = 5" ) ) ) ;
1230+ assert_eq ! ( qr. shard( ) , 2 ) ;
1231+
12141232 assert ! ( qr. infer( & simple_query(
12151233 "SELECT one, two, three FROM public.data WHERE id = 6"
12161234 ) ) ) ;
12171235 assert_eq ! ( qr. shard( ) , 0 ) ;
12181236
1237+ assert ! ( qr. infer( & simple_query(
1238+ "SELECT one, two, three FROM public.derived WHERE data_id = 6"
1239+ ) ) ) ;
1240+ assert_eq ! ( qr. shard( ) , 0 ) ;
1241+
12191242 assert ! ( qr. infer( & simple_query(
12201243 "SELECT * FROM data
12211244 INNER JOIN t2 ON data.id = 5
@@ -1242,7 +1265,8 @@ mod test {
12421265 assert_eq ! ( qr. shard( ) , 2 ) ;
12431266
12441267 // Super unique sharding key
1245- qr. pool_settings . automatic_sharding_key = Some ( "*.unique_enough_column_name" . to_string ( ) ) ;
1268+ qr. pool_settings . automatic_sharding_keys =
1269+ Some ( vec ! [ "*.unique_enough_column_name" . to_string( ) ] ) ;
12461270 assert ! ( qr. infer( & simple_query(
12471271 "SELECT * FROM table_x WHERE unique_enough_column_name = 6"
12481272 ) ) ) ;
@@ -1269,7 +1293,7 @@ mod test {
12691293 bind. put ( payload) ;
12701294
12711295 let mut qr = QueryRouter :: new ( ) ;
1272- qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1296+ qr. pool_settings . automatic_sharding_keys = Some ( vec ! [ "data.id" . to_string( ) ] ) ;
12731297 qr. pool_settings . shards = 3 ;
12741298
12751299 assert ! ( qr. infer( & simple_query( stmt) ) ) ;
0 commit comments