Skip to content

Allow for multiple automatic sharding keys #388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .circleci/pgcat.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ primary_reads_enabled = true
#
sharding_function = "pg_bigint_hash"

automatic_sharding_keys = ["data.id"]

# Credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
username = "sharding_user"
Expand Down
2 changes: 1 addition & 1 deletion pgcat.toml
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ sharding_function = "pg_bigint_hash"
# auth_query_password = "sharding_user"

# Automatically parse this from queries and route queries to the right shard!
# automatic_sharding_key = "data.id"
# automatic_sharding_keys = ["data.id"]

# Idle timeout can be overwritten in the pool
idle_timeout = 40000
Expand Down
36 changes: 18 additions & 18 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ pub struct Pool {
pub sharding_function: ShardingFunction,

#[serde(default = "Pool::default_automatic_sharding_key")]
pub automatic_sharding_key: Option<String>,
pub automatic_sharding_keys: Option<Vec<String>>,

pub sharding_key_regex: Option<String>,
pub shard_id_regex: Option<String>,
Expand Down Expand Up @@ -571,7 +571,7 @@ impl Pool {
LoadBalancingMode::Random
}

pub fn default_automatic_sharding_key() -> Option<String> {
pub fn default_automatic_sharding_key() -> Option<Vec<String>> {
None
}

Expand Down Expand Up @@ -627,23 +627,23 @@ impl Pool {
}
}

self.automatic_sharding_key = match &self.automatic_sharding_key {
Some(key) => {
// No quotes in the key so we don't have to compare quoted
// to unquoted idents.
let key = key.replace("\"", "");

if key.split(".").count() != 2 {
error!(
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
key, key
);
return Err(Error::BadConfig);
match &mut self.automatic_sharding_keys {
Some(keys) => {
for key in keys {
// No quotes in the key so we don't have to compare quoted
// to unquoted idents.
let key = key.replace("\"", "");

if key.split(".").count() != 2 {
error!(
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
key, key
);
return Err(Error::BadConfig);
}
}

Some(key)
}
None => None,
None => (),
};

for (_, user) in &self.users {
Expand All @@ -665,7 +665,7 @@ impl Default for Pool {
query_parser_enabled: false,
primary_reads_enabled: false,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
automatic_sharding_keys: None,
connect_timeout: None,
idle_timeout: None,
sharding_key_regex: None,
Expand Down
6 changes: 3 additions & 3 deletions src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ pub struct PoolSettings {
pub sharding_function: ShardingFunction,

// Sharding key
pub automatic_sharding_key: Option<String>,
pub automatic_sharding_keys: Option<Vec<String>>,

// Health check timeout
pub healthcheck_timeout: u64,
Expand Down Expand Up @@ -159,7 +159,7 @@ impl Default for PoolSettings {
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
automatic_sharding_keys: None,
healthcheck_delay: General::default_healthcheck_delay(),
healthcheck_timeout: General::default_healthcheck_timeout(),
ban_time: General::default_ban_time(),
Expand Down Expand Up @@ -458,7 +458,7 @@ impl ConnectionPool {
query_parser_enabled: pool_config.query_parser_enabled,
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
automatic_sharding_keys: pool_config.automatic_sharding_keys.clone(),
healthcheck_delay: config.general.healthcheck_delay,
healthcheck_timeout: config.general.healthcheck_timeout,
ban_time: config.general.ban_time,
Expand Down
86 changes: 69 additions & 17 deletions src/query_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ impl QueryRouter {

// Likely a read-only query
Query(query) => {
match &self.pool_settings.automatic_sharding_key {
match &self.pool_settings.automatic_sharding_keys {
Some(_) => {
// TODO: if we have multiple queries in the same message,
// we can either split them and execute them individually
Expand Down Expand Up @@ -571,17 +571,23 @@ impl QueryRouter {
let mut result = Vec::new();
let mut found = false;

let sharding_key = self
let sharding_keys = self
.pool_settings
.automatic_sharding_key
.automatic_sharding_keys
.as_ref()
.unwrap()
.split(".")
.map(|ident| Ident::new(ident))
.collect::<Vec<Ident>>();

// Sharding key must be always fully qualified
assert_eq!(sharding_key.len(), 2);
.iter()
.map(|x| {
x.split(".")
.map(|ident| Ident::new(ident))
.collect::<Vec<Ident>>()
})
.collect::<Vec<Vec<Ident>>>();

for sharding_key in sharding_keys.iter() {
// Sharding key must be always fully qualified
assert_eq!(sharding_key.len(), 2);
}

// This parses `sharding_key = 5`. But it's technically
// legal to write `5 = sharding_key`. I don't judge the people
Expand All @@ -593,7 +599,10 @@ impl QueryRouter {
Expr::Identifier(ident) => {
// Only if we're dealing with only one table
// and there is no ambiguity
if &ident.value == &sharding_key[1].value {
if let Some(sharding_key) = sharding_keys
.iter()
.find(|key| &ident.value == &key[1].value)
{
// Sharding key is unique enough, don't worry about
// table names.
if &sharding_key[0].value == "*" {
Expand Down Expand Up @@ -624,8 +633,13 @@ impl QueryRouter {
// The key is fully qualified in the query,
// it will exist or Postgres will throw an error.
if idents.len() == 2 {
found = &sharding_key[0].value == &idents[0].value
&& &sharding_key[1].value == &idents[1].value;
found = sharding_keys
Copy link
Contributor

@levkk levkk Mar 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You want to break here on found = true. If the loop continues, the second key that's not in the query will set the state as "sharding key not found", which isn't true.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure what you mean. The find() method breaks on the first entry for which the predicate is true. Maybe you mean that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. Hmm. In that case, I'm puzzled as to why the sharding integration Ruby test failed... Any ideas?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which pgcat.toml file do the ruby tests use? Does it use the automatic sharding keys? Because if not maybe I'm handling the None wrong? I will have closer look tomorrow.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't really get the ruby tests running locally. I tried to simulate what the ruby test is doing in rust. The following tests succeed both on the main branch and on my branch. So it seems like the shards are inferred correctly.

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 1")));
        assert_eq!(qr.shard(), 2);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 2")));
        assert_eq!(qr.shard(), 0);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 3")));
        assert_eq!(qr.shard(), 1);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 4")));
        assert_eq!(qr.shard(), 0);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5")));
        assert_eq!(qr.shard(), 2);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 6")));
        assert_eq!(qr.shard(), 0);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 7")));
        assert_eq!(qr.shard(), 1);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 8")));
        assert_eq!(qr.shard(), 0);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 9")));
        assert_eq!(qr.shard(), 2);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 10")));
        assert_eq!(qr.shard(), 1);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 11")));
        assert_eq!(qr.shard(), 2);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 12")));
        assert_eq!(qr.shard(), 2);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 13")));
        assert_eq!(qr.shard(), 1);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 14")));
        assert_eq!(qr.shard(), 1);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 15")));
        assert_eq!(qr.shard(), 0);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 16")));
        assert_eq!(qr.shard(), 0);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 17")));
        assert_eq!(qr.shard(), 2);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 18")));
        assert_eq!(qr.shard(), 0);

        assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 19")));
        assert_eq!(qr.shard(), 0);

Copy link
Contributor

@levkk levkk Apr 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's curious. You can run the tests locally with Docker:

  1. bash dev/scripts/console
  2. cargo build
  3. cd tests/ruby
  4. bundle install
  5. bundle exec rspec *_spec.rb

Ruby tests are very helpful for debugging issues - they run a real app that connects to PgCat externally and provide really good integration tests with real world use cases.

.iter()
.find(|key| {
&key[0].value == &idents[0].value
&& &key[1].value == &idents[1].value
})
.is_some();
}
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
}
Expand Down Expand Up @@ -1166,7 +1180,7 @@ mod test {
query_parser_enabled: true,
primary_reads_enabled: false,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: Some(String::from("test.id")),
automatic_sharding_keys: Some(vec![String::from("test.id")]),
healthcheck_delay: PoolSettings::default().healthcheck_delay,
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
ban_time: PoolSettings::default().ban_time,
Expand Down Expand Up @@ -1241,7 +1255,7 @@ mod test {
query_parser_enabled: true,
primary_reads_enabled: false,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
automatic_sharding_keys: None,
healthcheck_delay: PoolSettings::default().healthcheck_delay,
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
ban_time: PoolSettings::default().ban_time,
Expand Down Expand Up @@ -1282,14 +1296,31 @@ mod test {
QueryRouter::setup();

let mut qr = QueryRouter::new();
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.automatic_sharding_keys =
Some(vec!["data.id".to_string(), "derived.data_id".to_string()]);
qr.pool_settings.shards = 3;

assert!(qr
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
.is_ok());
assert_eq!(qr.shard(), 2);

assert!(qr
.infer(
&QueryRouter::parse(&simple_query("SELECT * FROM derived WHERE data_id = 5"))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2);

assert!(qr
.infer(
&QueryRouter::parse(&simple_query("SELECT * FROM derived WHERE data_id = 5"))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2);

assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
Expand All @@ -1300,6 +1331,26 @@ mod test {
.is_ok());
assert_eq!(qr.shard(), 0);

assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT one, two, three FROM public.derived WHERE data_id = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);

assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT one, two, three FROM public.derived WHERE data_id = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);

assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
Expand Down Expand Up @@ -1346,7 +1397,8 @@ mod test {
assert_eq!(qr.shard(), 2);

// Super unique sharding key
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
qr.pool_settings.automatic_sharding_keys =
Some(vec!["*.unique_enough_column_name".to_string()]);
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
Expand Down Expand Up @@ -1383,7 +1435,7 @@ mod test {
bind.put(payload);

let mut qr = QueryRouter::new();
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.automatic_sharding_keys = Some(vec!["data.id".to_string()]);
qr.pool_settings.shards = 3;

assert!(qr
Expand Down