From 48bb6ebeefbb37e1cc0eb0b7cfa8f624d42a43bf Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 28 Aug 2022 17:23:28 -0700 Subject: [PATCH 1/2] Support settings custom search path --- .circleci/pgcat.toml | 1 + src/client.rs | 8 +++++--- src/config.rs | 6 ++++++ src/messages.rs | 18 +++++++++++++++++- src/pool.rs | 1 + src/server.rs | 8 +++++++- 6 files changed, 37 insertions(+), 5 deletions(-) diff --git a/.circleci/pgcat.toml b/.circleci/pgcat.toml index 56aa1ddc..91742426 100644 --- a/.circleci/pgcat.toml +++ b/.circleci/pgcat.toml @@ -108,6 +108,7 @@ servers = [ ] # Database name (e.g. "postgres") database = "shard0" +search_path = "\"$user\",public" [pools.sharded_db.shards.1] servers = [ diff --git a/src/client.rs b/src/client.rs index 419448fb..e79bce14 100644 --- a/src/client.rs +++ b/src/client.rs @@ -354,6 +354,8 @@ where let stats = get_reporter(); let parameters = parse_startup(bytes.clone())?; + info!("params: {:?}", parameters); + // These two parameters are mandatory by the protocol. let pool_name = match parameters.get("database") { Some(db) => db, @@ -644,8 +646,8 @@ where // SET SHARD TO Some((Command::SetShard, _)) => { - // Selected shard is not configured. - if query_router.shard() >= pool.shards() { + let shard = query_router.shard(); + if shard >= pool.shards() { // Set the shard back to what it was. query_router.set_shard(current_shard); @@ -653,7 +655,7 @@ where &mut self.write, &format!( "shard {} is more than configured {}, staying on shard {}", - query_router.shard(), + shard, pool.shards(), current_shard, ), diff --git a/src/config.rs b/src/config.rs index 5c122611..7f21e937 100644 --- a/src/config.rs +++ b/src/config.rs @@ -72,6 +72,9 @@ pub struct Address { /// The name of the Postgres database. pub database: String, + /// Default search_path. + pub search_path: Option, + /// Server role: replica, primary. pub role: Role, @@ -98,6 +101,7 @@ impl Default for Address { address_index: 0, replica_number: 0, database: String::from("database"), + search_path: None, role: Role::Replica, username: String::from("username"), pool_name: String::from("pool_name"), @@ -206,6 +210,7 @@ impl Default for Pool { #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Shard { pub database: String, + pub search_path: Option, pub servers: Vec<(String, u16, String)>, } @@ -213,6 +218,7 @@ impl Default for Shard { fn default() -> Shard { Shard { servers: vec![(String::from("localhost"), 5432, String::from("primary"))], + search_path: None, database: String::from("postgres"), } } diff --git a/src/messages.rs b/src/messages.rs index 113e1ed5..867ab7b4 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -111,7 +111,12 @@ where /// Send the startup packet the server. We're pretending we're a Pg client. /// This tells the server which user we are and what database we want. -pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> { +pub async fn startup( + stream: &mut TcpStream, + user: &str, + database: &str, + search_path: Option<&String>, +) -> Result<(), Error> { let mut bytes = BytesMut::with_capacity(25); bytes.put_i32(196608); // Protocol number @@ -125,6 +130,17 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu bytes.put(&b"database\0"[..]); bytes.put_slice(&database.as_bytes()); bytes.put_u8(0); + + // search_path + match search_path { + Some(search_path) => { + bytes.put(&b"options\0"[..]); + bytes.put_slice(&format!("-c search_path={}", search_path).as_bytes()); + bytes.put_u8(0); + } + None => (), + }; + bytes.put_u8(0); // Null terminator let len = bytes.len() as i32 + 4i32; diff --git a/src/pool.rs b/src/pool.rs index 99cccaf1..a576e08f 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -155,6 +155,7 @@ impl ConnectionPool { let address = Address { id: address_id, database: shard.database.clone(), + search_path: shard.search_path.clone(), host: server.0.clone(), port: server.1 as u16, role: role, diff --git a/src/server.rs b/src/server.rs index 3134a65d..f5147ade 100644 --- a/src/server.rs +++ b/src/server.rs @@ -86,7 +86,13 @@ impl Server { trace!("Sending StartupMessage"); // StartupMessage - startup(&mut stream, &user.username, database).await?; + startup( + &mut stream, + &user.username, + database, + address.search_path.as_ref(), + ) + .await?; let mut server_info = BytesMut::new(); let mut process_id: i32 = 0; From 5872354c3e19ce001eddf8a389c2046523eb2578 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 28 Aug 2022 17:29:13 -0700 Subject: [PATCH 2/2] remove debug --- src/client.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/client.rs b/src/client.rs index e79bce14..b8c82a9c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -354,8 +354,6 @@ where let stats = get_reporter(); let parameters = parse_startup(bytes.clone())?; - info!("params: {:?}", parameters); - // These two parameters are mandatory by the protocol. let pool_name = match parameters.get("database") { Some(db) => db,